From 2188041f7b637d86126b8221b6224bfa8ce02e40 Mon Sep 17 00:00:00 2001 From: zestysoft Date: Tue, 12 Mar 2024 15:12:35 -0700 Subject: [PATCH] Add code to use openai_base_url and use OpenAI's model lister function Signed-off-by: zestysoft --- installer/client/cli/utils.py | 87 ++++++++++++++++------------------- 1 file changed, 39 insertions(+), 48 deletions(-) diff --git a/installer/client/cli/utils.py b/installer/client/cli/utils.py index defcd59..8e8c2ce 100644 --- a/installer/client/cli/utils.py +++ b/installer/client/cli/utils.py @@ -1,6 +1,6 @@ import requests import os -from openai import OpenAI +from openai import OpenAI, APIConnectionError import asyncio import pyperclip import sys @@ -36,12 +36,10 @@ class Standalone: # Expand the tilde to the full path env_file = os.path.expanduser(env_file) load_dotenv(env_file) - try: - apikey = os.environ["OPENAI_API_KEY"] - self.client = OpenAI() - self.client.api_key = apikey - except: - print("No API key found. Use the --apikey option to set the key") + assert 'OPENAI_API_KEY' in os.environ, "Error: OPENAI_API_KEY not found in environment variables. Please run fabric --setup and add a key." + api_key = os.environ['OPENAI_API_KEY'] + base_url = os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/') + self.client = OpenAI(api_key=api_key, base_url=base_url) self.local = False self.config_pattern_directory = config_directory self.pattern = pattern @@ -267,28 +265,24 @@ class Standalone: fullOllamaList = [] claudeList = ['claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-2.1'] try: - headers = { - "Authorization": f"Bearer {self.client.api_key}" - } - response = requests.get( - "https://api.openai.com/v1/models", headers=headers) - - if response.status_code == 200: - models = response.json().get("data", []) - # Filter only gpt models - gpt_models = [model for model in models if model.get( - "id", "").startswith(("gpt"))] - # Sort the models alphabetically by their ID - sorted_gpt_models = sorted( - gpt_models, key=lambda x: x.get("id")) - - for model in sorted_gpt_models: - gptlist.append(model.get("id")) + models = [model.id for model in self.client.models.list().data] + except APIConnectionError as e: + if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '": + print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.") + else: - print(f"Failed to fetch models: HTTP {response.status_code}") - sys.exit() - except: - print('No OpenAI API key found. Please run fabric --setup and add the key if you wish to interact with openai') + print(f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}") + sys.exit() + except Exception as e: + print(f"Error: {getattr(e.__context__, 'args', [''])[0]}") + sys.exit() + if "/" in models[0] or "\\" in models[0]: + # lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash + gptlist = [item[item.rfind("/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models] + else: + # Keep items that start with "gpt" + gptlist = [item for item in models if item.startswith("gpt")] + gptlist.sort() import ollama try: default_modelollamaList = ollama.list()['models'] @@ -436,27 +430,24 @@ class Setup: pass def fetch_available_models(self): - headers = { - "Authorization": f"Bearer {self.openaiapi_key}" - } - - response = requests.get( - "https://api.openai.com/v1/models", headers=headers) - - if response.status_code == 200: - models = response.json().get("data", []) - # Filter only gpt models - gpt_models = [model for model in models if model.get( - "id", "").startswith(("gpt"))] - # Sort the models alphabetically by their ID - sorted_gpt_models = sorted( - gpt_models, key=lambda x: x.get("id")) - - for model in sorted_gpt_models: - self.gptlist.append(model.get("id")) - else: - print(f"Failed to fetch models: HTTP {response.status_code}") + try: + models = [model.id for model in self.client.models.list().data] + except APIConnectionError as e: + if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '": + print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.") + else: + print(f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}") sys.exit() + except Exception as e: + print(f"Error: {getattr(e.__context__, 'args', [''])[0]}") + sys.exit() + if "/" in models[0] or "\\" in models[0]: + # lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash + self.gptlist = [item[item.rfind("/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models] + else: + # Keep items that start with "gpt" + self.gptlist = [item for item in models if item.startswith("gpt")] + self.gptlist.sort() import ollama try: default_modelollamaList = ollama.list()['models']