{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0adab93-e569-4af0-80f1-ce5b7a116507",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f583520-3c49-4e79-84ae-02bfc57f1e49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating a set of classes to simplify LLM use\n",
    "\n",
    "from abc import ABC, abstractmethod\n",
    "from dotenv import load_dotenv\n",
    "# Imports for type definition\n",
    "from collections.abc import MutableSequence\n",
    "from typing import TypedDict\n",
    "\n",
    "class LLM_Wrapper(ABC):\n",
    "    \"\"\"\n",
    "    The parent (abstract) class to specific LLM classes, normalising and providing common \n",
    "    and simplified ways to call LLMs while adding some level of abstraction on\n",
    "    specifics\n",
    "    \"\"\"\n",
    "\n",
    "    MessageEntry = TypedDict('MessageEntry', {'role': str, 'content': str})\n",
    "    \n",
    "    system_prompt: str                   # The system prompt used for the LLM\n",
    "    user_prompt: str                     # The user prompt\n",
    "    __api_key: str                       # The (private) api key\n",
    "    temperature: float = 0.5             # Default temperature\n",
    "    __msg: MutableSequence[MessageEntry] # Message builder\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str, env_apikey_var:str=None):\n",
    "        \"\"\"\n",
    "        env_apikey_var: str # The name of the env variable where to find the api_key\n",
    "                            # We store the retrieved api_key for future calls\n",
    "        \"\"\"\n",
    "        self.system_prompt = system_prompt\n",
    "        self.user_prompt = user_prompt\n",
    "        if env_apikey_var:\n",
    "            load_dotenv(override=True)\n",
    "            self.__api_key = os.getenv(env_apikey_var)\n",
    "\n",
    "        # # API Key format check\n",
    "        # if env_apikey_var and self.__api_key:\n",
    "        #     print(f\"API Key exists and begins {self.__api_key[:8]}\")\n",
    "        # else:\n",
    "        #     print(\"API Key not set\")\n",
    "            \n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        self.system_prompt = prompt\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        self.user_prompt = prompt\n",
    "\n",
    "    def setTemperature(self, temp:float):\n",
    "        self.temperature = temp\n",
    "\n",
    "    def getKey(self) -> str:\n",
    "        return self.__api_key\n",
    "\n",
    "    def messageSet(self, message: MutableSequence[MessageEntry]):\n",
    "        self.__msg = message\n",
    "\n",
    "    def messageAppend(self, role: str, content: str):\n",
    "        self.__msg.append(\n",
    "            {\"role\": role, \"content\": content}\n",
    "        )\n",
    "\n",
    "    def messageGet(self) -> MutableSequence[MessageEntry]:\n",
    "        return self.__msg\n",
    "        \n",
    "    @abstractmethod\n",
    "    def getResult(self):\n",
    "        pass\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a707f3ef-8696-44a9-943e-cfbce24b9fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "\n",
    "class GPT_Wrapper(LLM_Wrapper):\n",
    "\n",
    "    MODEL:str = 'gpt-4o-mini'\n",
    "    llm:OpenAI\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, \"OPENAI_API_KEY\")\n",
    "        self.llm = OpenAI()\n",
    "        super().messageSet([\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "        super().messageSet([\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "        super().messageSet([\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        format is sent as an adittional parameter {\"type\", format}\n",
    "        e.g. json_object\n",
    "        \"\"\"\n",
    "        if format:\n",
    "            response = self.llm.chat.completions.create(\n",
    "                model=self.MODEL,\n",
    "                messages=super().messageGet(),\n",
    "                temperature=self.temperature,\n",
    "                response_format={\"type\": \"json_object\"}\n",
    "            )\n",
    "            if format == \"json_object\":\n",
    "                result = json.loads(response.choices[0].message.content)\n",
    "            else:\n",
    "                result = response.choices[0].message.content\n",
    "        else:\n",
    "            response = self.llm.chat.completions.create(\n",
    "                model=self.MODEL,\n",
    "                messages=super().messageGet(),\n",
    "                temperature=self.temperature\n",
    "            )\n",
    "            result = response.choices[0].message.content\n",
    "        return result"
   ]
  },
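  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2a9c41-7b5e-4d2a-9c8f-1e6b2a4d7f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal usage sketch for GPT_Wrapper, not part of the wrapper itself.\n",
    "# Assumes a .env file providing OPENAI_API_KEY; the prompts are placeholders.\n",
    "\n",
    "gpt = GPT_Wrapper(\n",
    "    system_prompt=\"You are a concise technical assistant.\",\n",
    "    user_prompt=\"In one sentence, what is an abstract base class in Python?\"\n",
    ")\n",
    "gpt.setTemperature(0.2)\n",
    "print(gpt.getResult())"
   ]
  },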
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8529004-0d6a-480c-9634-7d51498255fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ollama\n",
    "\n",
    "class Ollama_Wrapper(LLM_Wrapper):\n",
    "\n",
    "    MODEL:str = 'llama3.2'\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, None)\n",
    "        self.llm=ollama\n",
    "        super().messageSet([\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "        super().messageSet([\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "        super().messageSet([\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        format is sent as an adittional parameter {\"type\", format}\n",
    "        e.g. json_object\n",
    "        \"\"\"\n",
    "        response = self.llm.chat(\n",
    "            model=self.MODEL, \n",
    "            messages=super().messageGet()\n",
    "        )\n",
    "        result = response['message']['content']\n",
    "        return result"
   ]
  },
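  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8e1f72-2c4b-4a6e-8b3d-9f0a1c2e4b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal usage sketch for Ollama_Wrapper, not part of the wrapper itself.\n",
    "# Assumes a local Ollama server is running with the llama3.2 model pulled;\n",
    "# the prompts are placeholders.\n",
    "\n",
    "llama = Ollama_Wrapper(\n",
    "    system_prompt=\"You are a concise technical assistant.\",\n",
    "    user_prompt=\"Name one advantage of running an LLM locally.\"\n",
    ")\n",
    "answer = llama.getResult()\n",
    "print(answer)\n",
    "\n",
    "# The base class message builder also supports multi-turn follow-ups\n",
    "llama.messageAppend(\"assistant\", answer)\n",
    "llama.messageAppend(\"user\", \"And one disadvantage?\")\n",
    "print(llama.getResult())"
   ]
  },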
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25ffb7e-0132-46cb-ad5b-18a300a7eb51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import anthropic\n",
    "\n",
    "class Claude_Wrapper(LLM_Wrapper):\n",
    "\n",
    "    MODEL:str = 'claude-3-5-haiku-20241022'\n",
    "    MAX_TOKENS:int = 200\n",
    "    llm:anthropic.Anthropic\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, \"ANTHROPIC_API_KEY\")\n",
    "        self.llm = anthropic.Anthropic()\n",
    "        super().messageSet([\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "        super().messageSet([\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        format is sent as an adittional parameter {\"type\", format}\n",
    "        e.g. json_object\n",
    "        \"\"\"\n",
    "        response = self.llm.messages.create(\n",
    "            model=self.MODEL,\n",
    "            max_tokens=self.MAX_TOKENS,\n",
    "            temperature=self.temperature,\n",
    "            system=self.system_prompt,\n",
    "            messages=super().messageGet()\n",
    "        )\n",
    "        result = response.content[0].text\n",
    "        return result"
   ]
  },
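  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4c2e91-6d3f-4b8a-a1c5-3e9f7d2b8c14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal usage sketch for Claude_Wrapper, not part of the wrapper itself.\n",
    "# Assumes a .env file providing ANTHROPIC_API_KEY; the prompts are placeholders.\n",
    "# Anthropic takes the system prompt as a separate argument, which is why this\n",
    "# wrapper keeps it out of the message list and passes it in getResult.\n",
    "\n",
    "claude = Claude_Wrapper(\n",
    "    system_prompt=\"You are a concise technical assistant.\",\n",
    "    user_prompt=\"In one sentence, what is dependency injection?\"\n",
    ")\n",
    "print(claude.getResult())"
   ]
  },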
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4379f1c0-6eeb-4611-8f34-a7303546ab71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import google.generativeai\n",
    "\n",
    "class Gemini_Wrapper(LLM_Wrapper):\n",
    "\n",
    "    MODEL:str = 'gemini-1.5-flash'\n",
    "    llm:google.generativeai.GenerativeModel\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, \"GOOGLE_API_KEY\")\n",
    "        self.llm = google.generativeai.GenerativeModel(\n",
    "            model_name=self.MODEL,\n",
    "            system_instruction=self.system_prompt\n",
    "        )\n",
    "        google.generativeai.configure(api_key=super().getKey())\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        format is sent as an adittional parameter {\"type\", format}\n",
    "        e.g. json_object\n",
    "        \"\"\"\n",
    "        response = self.llm.generate_content(self.user_prompt)\n",
    "        result = response.text\n",
    "        return result"
   ]
  }
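  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6d4f23-8e1a-4c7b-b2d6-5a0c8e3f1d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal usage sketch for Gemini_Wrapper, not part of the wrapper itself.\n",
    "# Assumes a .env file providing GOOGLE_API_KEY; the prompts are placeholders.\n",
    "# Since every wrapper shares the LLM_Wrapper interface, the same prompts can\n",
    "# be sent to any provider by swapping the class being instantiated.\n",
    "\n",
    "gemini = Gemini_Wrapper(\n",
    "    system_prompt=\"You are a concise technical assistant.\",\n",
    "    user_prompt=\"In one sentence, what is a context window?\"\n",
    ")\n",
    "print(gemini.getResult())"
   ]
  }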
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}