diff --git a/client/fabric/utils.py b/client/fabric/utils.py index c9f397d..1fd05cd 100644 --- a/client/fabric/utils.py +++ b/client/fabric/utils.py @@ -1,12 +1,12 @@ import requests import os -from openai import OpenAI import pyperclip import sys import platform from dotenv import load_dotenv from requests.exceptions import HTTPError from tqdm import tqdm +import litellm current_directory = os.path.dirname(os.path.realpath(__file__)) config_directory = os.path.expanduser("~/.config/fabric") @@ -25,29 +25,36 @@ class Standalone: Returns: None - - Raises: - KeyError: If the "OPENAI_API_KEY" is not found in the environment variables. - FileNotFoundError: If no API key is found in the environment variables. """ - # Expand the tilde to the full path - env_file = os.path.expanduser(env_file) - load_dotenv(env_file) try: - apikey = os.environ["OPENAI_API_KEY"] - self.client = OpenAI() - self.client.api_key = apikey - except KeyError: - print("OPENAI_API_KEY not found in environment variables.") + # Expand the tilde to the full path + env_file = os.path.expanduser(env_file) + load_dotenv(env_file) + self.args = args + self.config_pattern_directory = config_directory + + model, llm_provider, _, _ = litellm.get_llm_provider(args.model) + if not model: + raise ValueError( + """Model not found. Please check the model name. Use --listmodels to see available models or check the documentation for more information.""" + ) + if not llm_provider: + raise ValueError("LLM Provider not found. Please check the documentation for more information.") + + self.model = args.model + self.llm_provider = llm_provider + self.pattern = pattern + except FileNotFoundError: - print("No API key found. Use the --apikey option to set the key") + print("No environment file found. Please use the --setup option to initialize the required environment variables.") sys.exit() - self.config_pattern_directory = config_directory - self.pattern = pattern - self.args = args - self.model = args.model + + except Exception as e: + print(f"An error occurred: {e}") + sys.exit() + def streamMessage(self, input_data: str): """ Stream a message and handle exceptions. @@ -80,15 +87,17 @@ class Standalone: else: messages = [user_message] try: - stream = self.client.chat.completions.create( - model=self.model, - messages=messages, - temperature=0.0, - top_p=1, - frequency_penalty=0.1, - presence_penalty=0.1, - stream=True, - ) + arguments = { + "model": self.model, + "messages": messages, + "stream": True, + "temperature": 0.0, + "top_p": 1, + } + if self.llm_provider == "openai": + arguments["frequency_penalty"] = 0.1 + arguments["presence_penalty"] = 0.1 + stream = litellm.completion(**arguments) for chunk in stream: if chunk.choices[0].delta.content is not None: char = chunk.choices[0].delta.content @@ -139,14 +148,16 @@ class Standalone: else: messages = [user_message] try: - response = self.client.chat.completions.create( - model=self.model, - messages=messages, - temperature=0.0, - top_p=1, - frequency_penalty=0.1, - presence_penalty=0.1, - ) + arguments = { + "model": self.model, + "messages": messages, + "temperature": 0.0, + "top_p": 1, + } + if self.llm_provider == "openai": + arguments["frequency_penalty"] = 0.1 + arguments["presence_penalty"] = 0.1 + response = litellm.completion(**arguments) print(response.choices[0].message.content) except Exception as e: print(f"Error: {e}") @@ -158,10 +169,9 @@ class Standalone: f.write(response.choices[0].message.content) def fetch_available_models(self): - headers = { - "Authorization": f"Bearer { self.client.api_key }" - } - + """Fetch the available models from the OpenAI API.""" + headers = {"Authorization": f"Bearer { os.environ.get('OPENAI_API_KEY') }"} + response = requests.get("https://api.openai.com/v1/models", headers=headers) if response.status_code == 200: