|
|
|
@ -8,6 +8,7 @@ import platform
|
|
|
|
|
from dotenv import load_dotenv |
|
|
|
|
import zipfile |
|
|
|
|
import tempfile |
|
|
|
|
import re |
|
|
|
|
import shutil |
|
|
|
|
|
|
|
|
|
current_directory = os.path.dirname(os.path.realpath(__file__)) |
|
|
|
@ -424,17 +425,24 @@ class Setup:
|
|
|
|
|
self.gptlist = [] |
|
|
|
|
self.fullOllamaList = [] |
|
|
|
|
self.claudeList = ['claude-3-opus-20240229'] |
|
|
|
|
load_dotenv(self.env_file) |
|
|
|
|
try: |
|
|
|
|
openaiapikey = os.environ["OPENAI_API_KEY"] |
|
|
|
|
self.openaiapi_key = openaiapikey |
|
|
|
|
except KeyError: |
|
|
|
|
print("OPENAI_API_KEY not found in environment variables.") |
|
|
|
|
sys.exit() |
|
|
|
|
self.fetch_available_models() |
|
|
|
|
|
|
|
|
|
def fetch_available_models(self): |
|
|
|
|
headers = { |
|
|
|
|
"Authorization": f"Bearer {self.client.api_key}" |
|
|
|
|
"Authorization": f"Bearer {self.openaiapi_key}" |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
response = requests.get( |
|
|
|
|
"https://api.openai.com/v1/models", headers=headers) |
|
|
|
|
|
|
|
|
|
if response.status_code == 200: |
|
|
|
|
print("OpenAI GPT models:\n") |
|
|
|
|
models = response.json().get("data", []) |
|
|
|
|
# Filter only gpt models |
|
|
|
|
gpt_models = [model for model in models if model.get( |
|
|
|
@ -444,18 +452,19 @@ class Setup:
|
|
|
|
|
gpt_models, key=lambda x: x.get("id")) |
|
|
|
|
|
|
|
|
|
for model in sorted_gpt_models: |
|
|
|
|
print(model.get("id")) |
|
|
|
|
self.gptlist.append(model.get("id")) |
|
|
|
|
print("\nLocal Ollama models:") |
|
|
|
|
import ollama |
|
|
|
|
default_modelollamaList = ollama.list()['models'] |
|
|
|
|
for model in ollamaList: |
|
|
|
|
print(model['name'].rstrip(":latest")) |
|
|
|
|
self.fullOllamaList.append(model['name'].rstrip(":latest")) |
|
|
|
|
print("\nClaude models:") |
|
|
|
|
print("claude-3-opus-20240229") |
|
|
|
|
else: |
|
|
|
|
print(f"Failed to fetch models: HTTP {response.status_code}") |
|
|
|
|
sys.exit() |
|
|
|
|
import ollama |
|
|
|
|
try: |
|
|
|
|
default_modelollamaList = ollama.list()['models'] |
|
|
|
|
for model in default_modelollamaList: |
|
|
|
|
self.fullOllamaList.append(model['name'].rstrip(":latest")) |
|
|
|
|
except: |
|
|
|
|
self.fullOllamaList = [] |
|
|
|
|
allmodels = self.gptlist + self.fullOllamaList + self.claudeList |
|
|
|
|
return allmodels |
|
|
|
|
|
|
|
|
|
def api_key(self, api_key): |
|
|
|
|
""" Set the OpenAI API key in the environment file. |
|
|
|
@ -509,36 +518,69 @@ class Setup:
|
|
|
|
|
with open(self.env_file, "w") as f: |
|
|
|
|
f.write(f"CLAUDE_API_KEY={claude_key}") |
|
|
|
|
|
|
|
|
|
def update_fabric_command(self, line, model): |
|
|
|
|
fabric_command_regex = re.compile( |
|
|
|
|
r"(fabric --pattern\s+\S+.*?)( --claude| --local)?'") |
|
|
|
|
match = fabric_command_regex.search(line) |
|
|
|
|
if match: |
|
|
|
|
base_command = match.group(1) |
|
|
|
|
# Provide a default value for current_flag |
|
|
|
|
current_flag = match.group(2) if match.group(2) else "" |
|
|
|
|
new_flag = "" |
|
|
|
|
if model in self.claudeList: |
|
|
|
|
new_flag = " --claude" |
|
|
|
|
elif model in self.fullOllamaList: |
|
|
|
|
new_flag = " --local" |
|
|
|
|
# Update the command if the new flag is different or to remove an existing flag. |
|
|
|
|
# Ensure to add the closing quote that was part of the original regex |
|
|
|
|
return f"{base_command}{new_flag}'\n" |
|
|
|
|
else: |
|
|
|
|
return line # Return the line unmodified if no match is found. |
|
|
|
|
|
|
|
|
|
def update_fabric_alias(self, line, model): |
|
|
|
|
fabric_alias_regex = re.compile( |
|
|
|
|
r"(alias fabric='[^']+?)( --claude| --local)?'") |
|
|
|
|
match = fabric_alias_regex.search(line) |
|
|
|
|
if match: |
|
|
|
|
base_command, current_flag = match.groups() |
|
|
|
|
new_flag = "" |
|
|
|
|
if model in self.claudeList: |
|
|
|
|
new_flag = " --claude" |
|
|
|
|
elif model in self.fullOllamaList: |
|
|
|
|
new_flag = " --local" |
|
|
|
|
# Update the alias if the new flag is different or to remove an existing flag. |
|
|
|
|
return f"{base_command}{new_flag}'\n" |
|
|
|
|
else: |
|
|
|
|
return line # Return the line unmodified if no match is found. |
|
|
|
|
|
|
|
|
|
def default_model(self, model): |
|
|
|
|
""" Set the default model in the environment file. |
|
|
|
|
"""Set the default model in the environment file. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
model (str): The model to be set. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
model = model.strip() |
|
|
|
|
if os.path.exists(self.env_file) and model: |
|
|
|
|
with open(self.env_file, "r") as f: |
|
|
|
|
lines = f.readlines() |
|
|
|
|
with open(self.env_file, "w") as f: |
|
|
|
|
for line in lines: |
|
|
|
|
if "DEFAULT_MODEL" not in line: |
|
|
|
|
f.write(line) |
|
|
|
|
f.write(f"DEFAULT_MODEL={model}") |
|
|
|
|
elif model: |
|
|
|
|
with open(self.env_file, "w") as f: |
|
|
|
|
f.write(f"DEFAULT_MODEL={model}") |
|
|
|
|
else: |
|
|
|
|
with open(self.env_file, "r") as f: |
|
|
|
|
lines = f.readlines() |
|
|
|
|
with open(self.env_file, "w") as f: |
|
|
|
|
for line in lines: |
|
|
|
|
if "DEFAULT_MODEL" not in line: |
|
|
|
|
f.write(line) |
|
|
|
|
import re |
|
|
|
|
plain_fabric_regex = re.compile( |
|
|
|
|
r"(fabric='.*fabric)( --claude| --local)?'" |
|
|
|
|
fabric_regex = re.compile(r"(fabric --pattern.*)( --claude|--local)'") |
|
|
|
|
if model: |
|
|
|
|
# Write or update the DEFAULT_MODEL in env_file |
|
|
|
|
if os.path.exists(self.env_file): |
|
|
|
|
with open(self.env_file, "r") as f: |
|
|
|
|
lines = f.readlines() |
|
|
|
|
with open(self.env_file, "w") as f: |
|
|
|
|
found = False |
|
|
|
|
for line in lines: |
|
|
|
|
if line.startswith("DEFAULT_MODEL"): |
|
|
|
|
f.write(f"DEFAULT_MODEL={model}\n") |
|
|
|
|
found = True |
|
|
|
|
else: |
|
|
|
|
f.write(line) |
|
|
|
|
if not found: |
|
|
|
|
f.write(f"DEFAULT_MODEL={model}\n") |
|
|
|
|
else: |
|
|
|
|
with open(self.env_file, "w") as f: |
|
|
|
|
f.write(f"DEFAULT_MODEL={model}\n") |
|
|
|
|
|
|
|
|
|
# Compile regular expressions outside of the loop for efficiency |
|
|
|
|
|
|
|
|
|
user_home = os.path.expanduser("~") |
|
|
|
|
sh_config = None |
|
|
|
|
# Check for shell configuration files |
|
|
|
@ -552,17 +594,14 @@ class Setup:
|
|
|
|
|
lines = f.readlines() |
|
|
|
|
with open(sh_config, "w") as f: |
|
|
|
|
for line in lines: |
|
|
|
|
# Remove existing --claude or --local |
|
|
|
|
modified_line = re.sub(fabric_regex, r"\1'", line) |
|
|
|
|
modified_line = line |
|
|
|
|
# Update existing fabric commands |
|
|
|
|
if "fabric --pattern" in line: |
|
|
|
|
if model in self.claudeList: |
|
|
|
|
whole_thing = plain_fabric_regex.search(line)[0] |
|
|
|
|
beginning_match = plain_fabric_regex.search(line)[1] |
|
|
|
|
modified_line = re.sub( |
|
|
|
|
fabric_regex, r"\1 --claude'", line) |
|
|
|
|
elif model in self.fullOllamaList: |
|
|
|
|
modified_line = re.sub( |
|
|
|
|
fabric_regex, r"\1 --local'", line) |
|
|
|
|
modified_line = self.update_fabric_command( |
|
|
|
|
modified_line, model) |
|
|
|
|
elif "fabric=" in line: |
|
|
|
|
modified_line = self.update_fabric_alias( |
|
|
|
|
modified_line, model) |
|
|
|
|
f.write(modified_line) |
|
|
|
|
print(f"""Default model changed to { |
|
|
|
|
model}. Please restart your terminal to use it.""") |
|
|
|
|