diff --git a/installer/client/cli/utils.py b/installer/client/cli/utils.py index a9a1e4e..ce38774 100644 --- a/installer/client/cli/utils.py +++ b/installer/client/cli/utils.py @@ -49,18 +49,21 @@ class Standalone: self.config_pattern_directory = config_directory self.pattern = pattern self.args = args - self.model = args.model self.claude = claude - if self.local: - if self.args.model == 'gpt-4-turbo-preview': - self.args.model = 'llama2' - if self.claude: - if self.args.model == 'gpt-4-turbo-preview': - self.model = 'claude-3-opus-20240229' + try: + self.model = os.environ["CUSTOM_MODEL"] + except: + self.model = args.model + if self.local: + if self.args.model == 'gpt-4-turbo-preview': + self.model = 'llama2' + if self.claude: + if self.args.model == 'gpt-4-turbo-preview': + self.model = 'claude-3-opus-20240229' async def localChat(self, messages): from ollama import AsyncClient - response = await AsyncClient().chat(model=self.args.model, messages=messages) + response = await AsyncClient().chat(model=self.model, messages=messages) print(response['message']['content']) async def localStream(self, messages): @@ -458,8 +461,40 @@ class Setup: f.write(line) f.write(f"CLAUDE_API_KEY={claude_key}") elif claude_key: + with open(self.env_file, "r") as r: + lines = r.readlines() + with open(self.env_file, "w") as w: + for line in lines: + if "CLAUDE_API_KEY" not in line: + w.write(line) + w.write(f"CLAUDE_API_KEY={claude_key}") + + def custom_model(self, model): + """ + Set the custom model in the environment file + + Args: + model (str): The model to be set. + Returns: + None + """ + model = model.strip() + if os.path.exists(self.env_file) and model: + with open(self.env_file, "r") as f: + lines = f.readlines() with open(self.env_file, "w") as f: - f.write(f"CLAUDE_API_KEY={claude_key}") + for line in lines: + if "CUSTOM_MODEL" not in line: + f.write(line) + f.write(f"CUSTOM_MODEL={model}") + elif model: + with open(self.env_file, "r") as r: + lines = r.readlines() + with open(self.env_file, "w") as w: + for line in lines: + if "CUSTOM_MODEL" not in line: + w.write(line) + w.write(f"CUSTOM_MODEL={model}") def patterns(self): """ Method to update patterns and exit the system. @@ -482,10 +517,13 @@ class Setup: print("Welcome to Fabric. Let's get started.") apikey = input( "Please enter your OpenAI API key. If you do not have one or if you have already entered it, press enter.\n") - self.api_key(apikey.strip()) - print("Please enter your claude API key. If you do not have one, or if you have already entered it, press enter.\n") - claudekey = input() - self.claude_key(claudekey.strip()) + self.api_key(apikey) + claudekey = input( + "Please enter your claude API key. If you do not have one, or if you have already entered it, press enter.\n") + self.claude_key(claudekey) + custom_model = input( + "Please enter your custom model. If you do not have one, or if you have already entered it, press enter. If none is entered, it will default to gpt-4-turbo-preview\n") + self.custom_model(custom_model) self.patterns()