{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0adab93-e569-4af0-80f1-ce5b7a116507",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import json\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f583520-3c49-4e79-84ae-02bfc57f1e49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating a set of classes to simplify LLM use\n",
    "\n",
    "from abc import ABC, abstractmethod\n",
    "from dotenv import load_dotenv\n",
    "# Imports for type definition\n",
    "from collections.abc import MutableSequence\n",
    "from typing import Optional, TypedDict\n",
    "\n",
    "class LLM_Wrapper(ABC):\n",
    "    \"\"\"\n",
    "    The parent (abstract) class to specific LLM classes, normalising and providing common\n",
    "    and simplified ways to call LLMs while adding some level of abstraction on\n",
    "    specifics.\n",
    "\n",
    "    Subclasses implement getResult() and, when their provider keeps the prompts in a\n",
    "    message list, override the prompt setters to keep that list in sync.\n",
    "    \"\"\"\n",
    "\n",
    "    MessageEntry = TypedDict('MessageEntry', {'role': str, 'content': str})\n",
    "\n",
    "    system_prompt: str                    # The system prompt used for the LLM\n",
    "    user_prompt: str                      # The user prompt\n",
    "    __api_key: Optional[str]              # The (private) api key, None when not loaded\n",
    "    temperature: float = 0.5              # Default temperature\n",
    "    __msg: MutableSequence[MessageEntry]  # Message builder\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str, env_apikey_var:Optional[str]=None):\n",
    "        \"\"\"\n",
    "        system_prompt: the initial system prompt\n",
    "        user_prompt: the initial user prompt\n",
    "        env_apikey_var: name of the env variable holding the api key (a .env file is\n",
    "            loaded first). The retrieved key is stored for future calls; when\n",
    "            env_apikey_var is None no key is loaded and getKey() returns None.\n",
    "        \"\"\"\n",
    "        self.system_prompt = system_prompt\n",
    "        self.user_prompt = user_prompt\n",
    "        # Always initialise the private attributes: class-level annotations alone do not\n",
    "        # create them, so getKey()/messageGet()/messageAppend() would otherwise raise\n",
    "        # AttributeError for wrappers without a key or before messageSet() is called.\n",
    "        self.__api_key = None\n",
    "        self.__msg = []\n",
    "        if env_apikey_var:\n",
    "            load_dotenv(override=True)\n",
    "            self.__api_key = os.getenv(env_apikey_var)  # None if the variable is unset\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        self.system_prompt = prompt\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        self.user_prompt = prompt\n",
    "\n",
    "    def setTemperature(self, temp:float):\n",
    "        self.temperature = temp\n",
    "\n",
    "    def getKey(self) -> Optional[str]:\n",
    "        return self.__api_key\n",
    "\n",
    "    def messageSet(self, message: MutableSequence[MessageEntry]):\n",
    "        self.__msg = message\n",
    "\n",
    "    def messageAppend(self, role: str, content: str):\n",
    "        self.__msg.append(\n",
    "            {\"role\": role, \"content\": content}\n",
    "        )\n",
    "\n",
    "    def messageGet(self) -> MutableSequence[MessageEntry]:\n",
    "        return self.__msg\n",
    "\n",
    "    def _systemUserMessages(self) -> MutableSequence[MessageEntry]:\n",
    "        \"\"\"Build a fresh [system, user] message list from the current prompts.\"\"\"\n",
    "        return [\n",
    "            {\"role\": \"system\", \"content\": self.system_prompt},\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ]\n",
    "\n",
    "    @abstractmethod\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"Call the LLM with the current prompts and return its answer.\"\"\"\n",
    "        pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a707f3ef-8696-44a9-943e-cfbce24b9fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "\n",
    "class GPT_Wrapper(LLM_Wrapper):\n",
    "    \"\"\"LLM_Wrapper backed by the OpenAI chat completions API.\"\"\"\n",
    "\n",
    "    MODEL:str = 'gpt-4o-mini'\n",
    "    llm:OpenAI\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, \"OPENAI_API_KEY\")\n",
    "        self.llm = OpenAI()\n",
    "        super().messageSet(self._systemUserMessages())\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "        # Rebuild the message list so the next call sees the new prompt\n",
    "        super().messageSet(self._systemUserMessages())\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "        super().messageSet(self._systemUserMessages())\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        Send the current messages to the model and return its answer.\n",
    "\n",
    "        format: optional response format, sent as the additional parameter\n",
    "            response_format={\"type\": format} (e.g. \"json_object\"). With \"json_object\"\n",
    "            the reply is parsed with json.loads and the resulting object is returned;\n",
    "            otherwise the raw text content is returned.\n",
    "        \"\"\"\n",
    "        request = {\n",
    "            \"model\": self.MODEL,\n",
    "            \"messages\": super().messageGet(),\n",
    "            \"temperature\": self.temperature,\n",
    "        }\n",
    "        if format:\n",
    "            # Forward the requested format instead of always forcing json_object\n",
    "            request[\"response_format\"] = {\"type\": format}\n",
    "        response = self.llm.chat.completions.create(**request)\n",
    "        content = response.choices[0].message.content\n",
    "        if format == \"json_object\":\n",
    "            return json.loads(content)\n",
    "        return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8529004-0d6a-480c-9634-7d51498255fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ollama\n",
    "\n",
    "class Ollama_Wrapper(LLM_Wrapper):\n",
    "    \"\"\"LLM_Wrapper backed by a local Ollama server (no api key is loaded).\"\"\"\n",
    "\n",
    "    MODEL:str = 'llama3.2'\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, None)\n",
    "        self.llm = ollama\n",
    "        super().messageSet(self._systemUserMessages())\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "        # Rebuild the message list so the next call sees the new prompt\n",
    "        super().messageSet(self._systemUserMessages())\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "        super().messageSet(self._systemUserMessages())\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        Send the current messages to the local model and return the reply text.\n",
    "\n",
    "        format: accepted for interface compatibility with the other wrappers; it is\n",
    "            currently not forwarded to Ollama.\n",
    "        \"\"\"\n",
    "        response = self.llm.chat(\n",
    "            model=self.MODEL,\n",
    "            messages=super().messageGet(),\n",
    "            # Forward the wrapper temperature so setTemperature() has an effect\n",
    "            options={\"temperature\": self.temperature}\n",
    "        )\n",
    "        return response['message']['content']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25ffb7e-0132-46cb-ad5b-18a300a7eb51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import anthropic\n",
    "\n",
    "class Claude_Wrapper(LLM_Wrapper):\n",
    "    \"\"\"\n",
    "    LLM_Wrapper backed by the Anthropic messages API.\n",
    "\n",
    "    The system prompt is passed as the separate `system` argument, so the message\n",
    "    list only holds the user turn.\n",
    "    \"\"\"\n",
    "\n",
    "    MODEL:str = 'claude-3-5-haiku-20241022'\n",
    "    MAX_TOKENS:int = 200  # Passed as max_tokens to messages.create\n",
    "    llm:anthropic.Anthropic\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, \"ANTHROPIC_API_KEY\")\n",
    "        self.llm = anthropic.Anthropic()\n",
    "        super().messageSet([\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        # getResult() reads self.system_prompt directly; no message list to rebuild\n",
    "        super().setSystemPrompt(prompt)\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "        super().messageSet([\n",
    "            {\"role\": \"user\", \"content\": self.user_prompt}\n",
    "        ])\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        Send the user message (with the system prompt) and return the reply text.\n",
    "\n",
    "        format: accepted for interface compatibility with the other wrappers; it is\n",
    "            currently ignored.\n",
    "        \"\"\"\n",
    "        response = self.llm.messages.create(\n",
    "            model=self.MODEL,\n",
    "            max_tokens=self.MAX_TOKENS,\n",
    "            temperature=self.temperature,\n",
    "            system=self.system_prompt,\n",
    "            messages=super().messageGet()\n",
    "        )\n",
    "        return response.content[0].text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4379f1c0-6eeb-4611-8f34-a7303546ab71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import google.generativeai\n",
    "\n",
    "class Gemini_Wrapper(LLM_Wrapper):\n",
    "    \"\"\"\n",
    "    LLM_Wrapper backed by the google.generativeai SDK.\n",
    "\n",
    "    The system prompt is given to the GenerativeModel as its system_instruction at\n",
    "    construction time, so the model is rebuilt whenever the system prompt changes.\n",
    "    \"\"\"\n",
    "\n",
    "    MODEL:str = 'gemini-1.5-flash'\n",
    "    llm:google.generativeai.GenerativeModel\n",
    "\n",
    "    def __init__(self, system_prompt:str, user_prompt:str):\n",
    "        super().__init__(system_prompt, user_prompt, \"GOOGLE_API_KEY\")\n",
    "        # Configure the key before building the model that will use it\n",
    "        google.generativeai.configure(api_key=super().getKey())\n",
    "        self.llm = self._buildModel()\n",
    "\n",
    "    def _buildModel(self) -> google.generativeai.GenerativeModel:\n",
    "        \"\"\"Create a model bound to the current system prompt.\"\"\"\n",
    "        return google.generativeai.GenerativeModel(\n",
    "            model_name=self.MODEL,\n",
    "            system_instruction=self.system_prompt\n",
    "        )\n",
    "\n",
    "    def setSystemPrompt(self, prompt:str):\n",
    "        super().setSystemPrompt(prompt)\n",
    "        # Without rebuilding, the model would keep the old system_instruction\n",
    "        self.llm = self._buildModel()\n",
    "\n",
    "    def setUserPrompt(self, prompt:str):\n",
    "        super().setUserPrompt(prompt)\n",
    "\n",
    "    def getResult(self, format=None):\n",
    "        \"\"\"\n",
    "        Send the user prompt to the model and return the reply text.\n",
    "\n",
    "        format: accepted for interface compatibility with the other wrappers; it is\n",
    "            currently ignored.\n",
    "        \"\"\"\n",
    "        response = self.llm.generate_content(\n",
    "            self.user_prompt,\n",
    "            # Forward the wrapper temperature so setTemperature() has an effect\n",
    "            generation_config={\"temperature\": self.temperature}\n",
    "        )\n",
    "        return response.text"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}