diff --git a/installer/client/cli/fabric.py b/installer/client/cli/fabric.py index cbbc68a..0fce58a 100755 --- a/installer/client/cli/fabric.py +++ b/installer/client/cli/fabric.py @@ -39,6 +39,14 @@ def main(): parser.add_argument( "--list", "-l", help="List available patterns", action="store_true" ) + parser.add_argument( + '--temp', help="set the temperature for the model. Default is 0", default=0, type=float) + parser.add_argument( + '--top_p', help="set the top_p for the model. Default is 1", default=1, type=float) + parser.add_argument( + '--frequency_penalty', help="set the frequency penalty for the model. Default is 0.1", default=0.1, type=float) + parser.add_argument( + '--presence_penalty', help="set the presence penalty for the model. Default is 0.1", default=0.1, type=float) parser.add_argument( "--update", "-u", help="Update patterns. NOTE: This will revert the default model to gpt4-turbo. please run --changeDefaultModel to once again set default model", action="store_true") parser.add_argument("--pattern", "-p", help="The pattern (prompt) to use") diff --git a/installer/client/cli/utils.py b/installer/client/cli/utils.py index 9aa32a9..14e6a7d 100644 --- a/installer/client/cli/utils.py +++ b/installer/client/cli/utils.py @@ -87,7 +87,7 @@ class Standalone: max_tokens=4096, system=system, messages=[user], - model=self.model, temperature=0.0, top_p=1.0 + model=self.model, temperature=self.args.temp, top_p=self.args.top_p ) as stream: async for text in stream.text_stream: print(text, end="", flush=True) @@ -104,7 +104,7 @@ class Standalone: system=system, messages=[user], model=self.model, - temperature=0.0, top_p=1.0 + temperature=self.args.temp, top_p=self.args.top_p ) print(message.content[0].text) copy = self.args.copy @@ -162,10 +162,10 @@ class Standalone: stream = self.client.chat.completions.create( model=self.model, messages=messages, - temperature=0.0, - top_p=1, - frequency_penalty=0.1, - presence_penalty=0.1, + temperature=self.args.temp, + top_p=self.args.top_p, + frequency_penalty=self.args.frequency_penalty, + presence_penalty=self.args.presence_penalty, stream=True, ) for chunk in stream: @@ -247,10 +247,10 @@ class Standalone: response = self.client.chat.completions.create( model=self.model, messages=messages, - temperature=0.0, - top_p=1, - frequency_penalty=0.1, - presence_penalty=0.1, + temperature=self.args.temp, + top_p=self.args.top_p, + frequency_penalty=self.args.frequency_penalty, + presence_penalty=self.args.presence_penalty, ) print(response.choices[0].message.content) if self.args.copy: diff --git a/installer/client/gui/index.html b/installer/client/gui/index.html index d84095f..9c539e7 100644 --- a/installer/client/gui/index.html +++ b/installer/client/gui/index.html @@ -39,6 +39,12 @@ +
Dark @@ -91,6 +97,56 @@ />
+ diff --git a/installer/client/gui/main.js b/installer/client/gui/main.js index ae93eee..0d14c33 100644 --- a/installer/client/gui/main.js +++ b/installer/client/gui/main.js @@ -286,7 +286,16 @@ async function getPatternContent(patternName) { } } -async function ollamaMessage(system, user, model, event) { +async function ollamaMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event +) { ollama = new Ollama.Ollama(); const userMessage = { role: "user", @@ -296,6 +305,10 @@ async function ollamaMessage(system, user, model, event) { const response = await ollama.chat({ model: model, messages: [systemMessage, userMessage], + temperature: temperature, + top_p: topP, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, stream: true, }); let responseMessage = ""; @@ -309,13 +322,26 @@ async function ollamaMessage(system, user, model, event) { } } -async function openaiMessage(system, user, model, event) { +async function openaiMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event +) { const userMessage = { role: "user", content: user }; const systemMessage = { role: "system", content: system }; const stream = await openai.chat.completions.create( { model: model, messages: [systemMessage, userMessage], + temperature: temperature, + top_p: topP, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, stream: true, }, { responseType: "stream" } @@ -334,7 +360,7 @@ async function openaiMessage(system, user, model, event) { event.reply("model-response-end", responseMessage); } -async function claudeMessage(system, user, model, event) { +async function claudeMessage(system, user, model, temperature, topP, event) { if (!claude) { event.reply( "model-response-error", @@ -351,8 +377,8 @@ async function claudeMessage(system, user, model, event) { max_tokens: 4096, messages: [userMessage], stream: true, - temperature: 0.0, - top_p: 1.0, + temperature: temperature, + top_p: topP, }); let responseMessage = ""; for await (const chunk of response) { @@ -409,32 +435,62 @@ function createWindow() { }); } -ipcMain.on("start-query", async (event, system, user, model) => { - if (system == null || user == null || model == null) { - console.error("Received null for system, user message, or model"); - event.reply( - "model-response-error", - "Error: System, user message, or model is null." - ); - return; - } +ipcMain.on( + "start-query", + async ( + event, + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty + ) => { + if (system == null || user == null || model == null) { + console.error("Received null for system, user message, or model"); + event.reply( + "model-response-error", + "Error: System, user message, or model is null." + ); + return; + } - try { - const _gptModels = allModels.gptModels.map((model) => model.id); - if (allModels.claudeModels.includes(model)) { - await claudeMessage(system, user, model, event); - } else if (_gptModels.includes(model)) { - await openaiMessage(system, user, model, event); - } else if (allModels.ollamaModels.includes(model)) { - await ollamaMessage(system, user, model, event); - } else { - event.reply("model-response-error", "Unsupported model: " + model); + try { + const _gptModels = allModels.gptModels.map((model) => model.id); + if (allModels.claudeModels.includes(model)) { + await claudeMessage(system, user, model, temperature, topP, event); + } else if (_gptModels.includes(model)) { + await openaiMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event + ); + } else if (allModels.ollamaModels.includes(model)) { + await ollamaMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event + ); + } else { + event.reply("model-response-error", "Unsupported model: " + model); + } + } catch (error) { + console.error("Error querying model:", error); + event.reply("model-response-error", "Error querying model."); } - } catch (error) { - console.error("Error querying model:", error); - event.reply("model-response-error", "Error querying model."); } -}); +); ipcMain.handle("create-pattern", async (event, patternName, patternContent) => { try { diff --git a/installer/client/gui/static/js/index.js b/installer/client/gui/static/js/index.js index dc7218b..e66a7e6 100644 --- a/installer/client/gui/static/js/index.js +++ b/installer/client/gui/static/js/index.js @@ -14,6 +14,22 @@ document.addEventListener("DOMContentLoaded", async function () { const updatePatternButton = document.getElementById("createPattern"); const patternCreator = document.getElementById("patternCreator"); const submitPatternButton = document.getElementById("submitPattern"); + const fineTuningButton = document.getElementById("fineTuningButton"); + const fineTuningSection = document.getElementById("fineTuningSection"); + const temperatureSlider = document.getElementById("temperatureSlider"); + const temperatureValue = document.getElementById("temperatureValue"); + const topPSlider = document.getElementById("topPSlider"); + const topPValue = document.getElementById("topPValue"); + const frequencyPenaltySlider = document.getElementById( + "frequencyPenaltySlider" + ); + const frequencyPenaltyValue = document.getElementById( + "frequencyPenaltyValue" + ); + const presencePenaltySlider = document.getElementById( + "presencePenaltySlider" + ); + const presencePenaltyValue = document.getElementById("presencePenaltyValue"); const myForm = document.getElementById("my-form"); const copyButton = document.createElement("button"); @@ -55,6 +71,10 @@ document.addEventListener("DOMContentLoaded", async function () { } async function submitQuery(userInputValue) { + const temperature = parseFloat(temperatureSlider.value); + const topP = parseFloat(topPSlider.value); + const frequencyPenalty = parseFloat(frequencyPenaltySlider.value); + const presencePenalty = parseFloat(presencePenaltySlider.value); userInput.value = ""; // Clear the input after submitting const systemCommand = await window.electronAPI.invoke( "get-pattern-content", @@ -70,7 +90,11 @@ document.addEventListener("DOMContentLoaded", async function () { "start-query", systemCommand, userInputValue, - selectedModel + selectedModel, + temperature, + topP, + frequencyPenalty, + presencePenalty ); } @@ -222,6 +246,27 @@ document.addEventListener("DOMContentLoaded", async function () { submitQuery(userInputValue); }); + fineTuningButton.addEventListener("click", function (e) { + e.preventDefault(); + fineTuningSection.classList.toggle("hidden"); + }); + + temperatureSlider.addEventListener("input", function () { + temperatureValue.textContent = this.value; + }); + + topPSlider.addEventListener("input", function () { + topPValue.textContent = this.value; + }); + + frequencyPenaltySlider.addEventListener("input", function () { + frequencyPenaltyValue.textContent = this.value; + }); + + presencePenaltySlider.addEventListener("input", function () { + presencePenaltyValue.textContent = this.value; + }); + submitPatternButton.addEventListener("click", async () => { const patternName = document.getElementById("patternName").value; const patternText = document.getElementById("patternBody").value;