Browse Source

[FEATURE] Use LiteLLM for inference

Instead of using OpenAI, the user can now decide which LLM Provider or Model is supposed to be used for inference.
See LiteLLM-Documentation for further information and compatibility: https://docs.litellm.ai/docs/
pull/79/head
Chris 1 year ago
parent
commit
b40bb87e1e
  1. 88
      client/fabric/utils.py

88
client/fabric/utils.py

@ -1,12 +1,12 @@
import requests
import os
from openai import OpenAI
import pyperclip
import sys
import platform
from dotenv import load_dotenv
from requests.exceptions import HTTPError
from tqdm import tqdm
import litellm
current_directory = os.path.dirname(os.path.realpath(__file__))
config_directory = os.path.expanduser("~/.config/fabric")
@ -25,29 +25,36 @@ class Standalone:
Returns:
None
Raises:
KeyError: If the "OPENAI_API_KEY" is not found in the environment variables.
FileNotFoundError: If no API key is found in the environment variables.
"""
# Expand the tilde to the full path
env_file = os.path.expanduser(env_file)
load_dotenv(env_file)
try:
apikey = os.environ["OPENAI_API_KEY"]
self.client = OpenAI()
self.client.api_key = apikey
except KeyError:
print("OPENAI_API_KEY not found in environment variables.")
# Expand the tilde to the full path
env_file = os.path.expanduser(env_file)
load_dotenv(env_file)
self.args = args
self.config_pattern_directory = config_directory
model, llm_provider, _, _ = litellm.get_llm_provider(args.model)
if not model:
raise ValueError(
"""Model not found. Please check the model name. Use --listmodels to see available models or check the documentation for more information."""
)
if not llm_provider:
raise ValueError("LLM Provider not found. Please check the documentation for more information.")
self.model = args.model
self.llm_provider = llm_provider
self.pattern = pattern
except FileNotFoundError:
print("No API key found. Use the --apikey option to set the key")
print("No environment file found. Please use the --setup option to initialize the required environment variables.")
sys.exit()
self.config_pattern_directory = config_directory
self.pattern = pattern
self.args = args
self.model = args.model
except Exception as e:
print(f"An error occurred: {e}")
sys.exit()
def streamMessage(self, input_data: str):
""" Stream a message and handle exceptions.
@ -80,15 +87,17 @@ class Standalone:
else:
messages = [user_message]
try:
stream = self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.0,
top_p=1,
frequency_penalty=0.1,
presence_penalty=0.1,
stream=True,
)
arguments = {
"model": self.model,
"messages": messages,
"stream": True,
"temperature": 0.0,
"top_p": 1,
}
if self.llm_provider == "openai":
arguments["frequency_penalty"] = 0.1
arguments["presence_penalty"] = 0.1
stream = litellm.completion(**arguments)
for chunk in stream:
if chunk.choices[0].delta.content is not None:
char = chunk.choices[0].delta.content
@ -139,14 +148,16 @@ class Standalone:
else:
messages = [user_message]
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.0,
top_p=1,
frequency_penalty=0.1,
presence_penalty=0.1,
)
arguments = {
"model": self.model,
"messages": messages,
"temperature": 0.0,
"top_p": 1,
}
if self.llm_provider == "openai":
arguments["frequency_penalty"] = 0.1
arguments["presence_penalty"] = 0.1
response = litellm.completion(**arguments)
print(response.choices[0].message.content)
except Exception as e:
print(f"Error: {e}")
@ -158,10 +169,9 @@ class Standalone:
f.write(response.choices[0].message.content)
def fetch_available_models(self):
headers = {
"Authorization": f"Bearer { self.client.api_key }"
}
"""Fetch the available models from the OpenAI API."""
headers = {"Authorization": f"Bearer { os.environ.get('OPENAI_API_KEY') }"}
response = requests.get("https://api.openai.com/v1/models", headers=headers)
if response.status_code == 200:

Loading…
Cancel
Save