
[FEATURE] Use LiteLLM for inference

Instead of being limited to OpenAI, the user can now choose which LLM provider or model should be used for inference.
See the LiteLLM documentation for supported providers and models: https://docs.litellm.ai/docs/
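
For context, LiteLLM exposes an OpenAI-compatible `completion()` call that routes to a provider based on the model string. A minimal sketch of that unified interface (not part of this commit; the model names are illustrative):

```python
# Minimal sketch of LiteLLM's unified interface (not from this commit;
# model names are illustrative, API keys come from the environment).
import litellm

# The same call works across providers; LiteLLM routes on the model string.
response = litellm.completion(
    model="gpt-3.5-turbo",  # OpenAI; reads OPENAI_API_KEY from the environment
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)

# Swapping providers is just a different model string, e.g. "ollama/mistral"
# or "anthropic/claude-3-haiku-20240307" -- see https://docs.litellm.ai/docs/
```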
Branch: pull/79/head
Author: Chris, 1 year ago
Parent commit: b40bb87e1e

1 changed file, 88 lines changed: client/fabric/utils.py
```diff
@@ -1,12 +1,12 @@
 import requests
 import os
-from openai import OpenAI
 import pyperclip
 import sys
 import platform
 from dotenv import load_dotenv
 from requests.exceptions import HTTPError
 from tqdm import tqdm
+import litellm

 current_directory = os.path.dirname(os.path.realpath(__file__))
 config_directory = os.path.expanduser("~/.config/fabric")
```
```diff
@@ -25,29 +25,36 @@ class Standalone:
         Returns:
             None
-
-        Raises:
-            KeyError: If the "OPENAI_API_KEY" is not found in the environment variables.
-            FileNotFoundError: If no API key is found in the environment variables.
         """
-        # Expand the tilde to the full path
-        env_file = os.path.expanduser(env_file)
-        load_dotenv(env_file)
         try:
-            apikey = os.environ["OPENAI_API_KEY"]
-            self.client = OpenAI()
-            self.client.api_key = apikey
-        except KeyError:
-            print("OPENAI_API_KEY not found in environment variables.")
+            # Expand the tilde to the full path
+            env_file = os.path.expanduser(env_file)
+            load_dotenv(env_file)
+            self.args = args
+            self.config_pattern_directory = config_directory
+            model, llm_provider, _, _ = litellm.get_llm_provider(args.model)
+            if not model:
+                raise ValueError(
+                    """Model not found. Please check the model name. Use --listmodels to see available models or check the documentation for more information."""
+                )
+            if not llm_provider:
+                raise ValueError("LLM Provider not found. Please check the documentation for more information.")
+            self.model = args.model
+            self.llm_provider = llm_provider
+            self.pattern = pattern
         except FileNotFoundError:
-            print("No API key found. Use the --apikey option to set the key")
+            print("No environment file found. Please use the --setup option to initialize the required environment variables.")
             sys.exit()
-        self.config_pattern_directory = config_directory
-        self.pattern = pattern
-        self.args = args
-        self.model = args.model
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            sys.exit()

     def streamMessage(self, input_data: str):
         """ Stream a message and handle exceptions.
```
```diff
@@ -80,15 +87,17 @@ class Standalone:
         else:
             messages = [user_message]
         try:
-            stream = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=0.0,
-                top_p=1,
-                frequency_penalty=0.1,
-                presence_penalty=0.1,
-                stream=True,
-            )
+            arguments = {
+                "model": self.model,
+                "messages": messages,
+                "stream": True,
+                "temperature": 0.0,
+                "top_p": 1,
+            }
+            if self.llm_provider == "openai":
+                arguments["frequency_penalty"] = 0.1
+                arguments["presence_penalty"] = 0.1
+            stream = litellm.completion(**arguments)
             for chunk in stream:
                 if chunk.choices[0].delta.content is not None:
                     char = chunk.choices[0].delta.content
```
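
The streaming path builds a kwargs dict and passes it to `litellm.completion()`; the returned chunks mimic OpenAI's delta format, which is why the consumption loop is unchanged. A self-contained sketch of that loop (model name and prompt are illustrative):

```python
# Sketch of consuming a LiteLLM stream, mirroring the loop in the hunk above
# (assumes litellm is installed and an API key is configured).
import litellm

stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Write a haiku about the sea."}],
    stream=True,
    temperature=0.0,
    top_p=1,
)
buffer = ""
for chunk in stream:
    char = chunk.choices[0].delta.content
    if char is not None:
        buffer += char
        print(char, end="", flush=True)
```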
```diff
@@ -139,14 +148,16 @@ class Standalone:
         else:
             messages = [user_message]
         try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=0.0,
-                top_p=1,
-                frequency_penalty=0.1,
-                presence_penalty=0.1,
-            )
+            arguments = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": 0.0,
+                "top_p": 1,
+            }
+            if self.llm_provider == "openai":
+                arguments["frequency_penalty"] = 0.1
+                arguments["presence_penalty"] = 0.1
+            response = litellm.completion(**arguments)
             print(response.choices[0].message.content)
         except Exception as e:
             print(f"Error: {e}")
```
```diff
@@ -158,10 +169,9 @@ class Standalone:
                 f.write(response.choices[0].message.content)

     def fetch_available_models(self):
-        headers = {
-            "Authorization": f"Bearer { self.client.api_key }"
-        }
+        """Fetch the available models from the OpenAI API."""
+        headers = {"Authorization": f"Bearer { os.environ.get('OPENAI_API_KEY') }"}
         response = requests.get("https://api.openai.com/v1/models", headers=headers)
         if response.status_code == 200:
```
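
The diff is truncated after the status-code check. For reference, a standalone sketch of the same models fetch, where the response parsing is an assumption based on the shape of the OpenAI `/v1/models` response:

```python
# Standalone sketch of the models fetch; the JSON parsing is an assumption,
# since the diff is truncated after the status-code check.
import os
import requests

headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = requests.get("https://api.openai.com/v1/models", headers=headers)
if response.status_code == 200:
    models = [m["id"] for m in response.json()["data"]]
    print("\n".join(sorted(models)))
else:
    print(f"Failed to fetch models: HTTP {response.status_code}")
```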
