From 7fc4b7b4eb3457b8cd74b0c0d90134ccd6fb66a5 Mon Sep 17 00:00:00 2001 From: Ionite Date: Tue, 15 Aug 2023 01:04:17 -0400 Subject: [PATCH] Add CompletionProvider, debug option for loading --- StabilityMatrix.Avalonia/App.axaml.cs | 3 + .../Behaviors/TextEditorCompletionBehavior.cs | 24 ++-- .../Controls/CodeCompletion/CompletionList.cs | 7 +- .../CodeCompletion/CompletionWindow.axaml.cs | 29 ++++- .../Controls/PromptCard.axaml | 6 +- .../DesignData/DesignData.cs | 1 + .../Helpers/TagCsvParser.cs | 70 +++++++++++ .../TagCompletion/CompletionProvider.cs | 119 ++++++++++++++++++ .../TagCompletion/ICompletionProvider.cs | 28 +++++ .../StabilityMatrix.Avalonia.csproj | 4 + .../Inference/PromptCardViewModel.cs | 14 ++- .../ViewModels/SettingsViewModel.cs | 19 ++- .../Views/SettingsPage.axaml | 8 ++ 13 files changed, 302 insertions(+), 30 deletions(-) create mode 100644 StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs create mode 100644 StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs create mode 100644 StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs diff --git a/StabilityMatrix.Avalonia/App.axaml.cs b/StabilityMatrix.Avalonia/App.axaml.cs index 3274f04a..22800d44 100644 --- a/StabilityMatrix.Avalonia/App.axaml.cs +++ b/StabilityMatrix.Avalonia/App.axaml.cs @@ -32,9 +32,11 @@ using Polly.Timeout; using Refit; using Sentry; using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Controls.CodeCompletion; using StabilityMatrix.Avalonia.DesignData; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels; using StabilityMatrix.Avalonia.ViewModels.Dialogs; @@ -365,6 +367,7 @@ public sealed class App : Application services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); // Rich presence services.AddSingleton(); diff --git a/StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs b/StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs index 23a0dee3..0e8aa754 100644 --- a/StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs +++ b/StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs @@ -9,6 +9,7 @@ using AvaloniaEdit.Editing; using NLog; using StabilityMatrix.Avalonia.Controls.CodeCompletion; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.TagCompletion; using CompletionWindow = StabilityMatrix.Avalonia.Controls.CodeCompletion.CompletionWindow; namespace StabilityMatrix.Avalonia.Behaviors; @@ -22,13 +23,13 @@ public class TextEditorCompletionBehavior : Behavior private CompletionWindow? completionWindow; // ReSharper disable once MemberCanBePrivate.Global - public static readonly StyledProperty TextProperty = - AvaloniaProperty.Register(nameof(Text)); + public static readonly StyledProperty CompletionProviderProperty = + AvaloniaProperty.Register(nameof(CompletionProvider)); - public string Text + public ICompletionProvider CompletionProvider { - get => GetValue(TextProperty); - set => SetValue(TextProperty, value); + get => GetValue(CompletionProviderProperty); + set => SetValue(CompletionProviderProperty, value); } protected override void OnAttached() @@ -55,7 +56,7 @@ public class TextEditorCompletionBehavior : Behavior private CompletionWindow CreateCompletionWindow(TextArea textArea) { - var window = new CompletionWindow(textArea) + var window = new CompletionWindow(textArea, CompletionProvider) { WindowManagerAddShadowHint = false, CloseWhenCaretAtBeginning = true, @@ -66,13 +67,6 @@ public class TextEditorCompletionBehavior : Behavior IsFiltering = true } }; - - var completionList = window.CompletionList; - - completionList.CompletionData.Add(new CompletionData("item1")); - completionList.CompletionData.Add(new CompletionData("item2")); - completionList.CompletionData.Add(new CompletionData("item3")); - return window; } @@ -98,8 +92,8 @@ public class TextEditorCompletionBehavior : Behavior completionWindow = CreateCompletionWindow(textEditor.TextArea); completionWindow.StartOffset = tokenSegment.Offset; completionWindow.EndOffset = tokenSegment.EndOffset; - - completionWindow.CompletionList.SelectItem(token); + + completionWindow.UpdateQuery(token); completionWindow.Closed += delegate { diff --git a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs index 785ec041..c14302dc 100644 --- a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs +++ b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs @@ -284,14 +284,15 @@ public class CompletionList : TemplatedControl private void SelectItemFiltering(string query) { // if the user just typed one more character, don't filter all data but just filter what we are already displaying - var listToFilter = + /*var listToFilter = _currentList != null && !string.IsNullOrEmpty(_currentText) && !string.IsNullOrEmpty(query) && query.StartsWith(_currentText, StringComparison.Ordinal) ? _currentList - : _completionData; - + : _completionData;*/ + var listToFilter = _completionData; + var matchingItems = from item in listToFilter let quality = GetMatchQuality(item.Text, query) diff --git a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs index e8a315f7..20b1391e 100644 --- a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs @@ -25,6 +25,8 @@ using Avalonia.Input; using Avalonia.Media; using AvaloniaEdit.Document; using AvaloniaEdit.Editing; +using AvaloniaEdit.Utils; +using StabilityMatrix.Avalonia.Models.TagCompletion; namespace StabilityMatrix.Avalonia.Controls.CodeCompletion; @@ -35,7 +37,8 @@ public class CompletionWindow : CompletionWindowBase { private PopupWithCustomPosition _toolTip; private ContentControl _toolTipContent; - + + private ICompletionProvider completionProvider; /// /// Gets the completion list used in this completion window. @@ -45,9 +48,15 @@ public class CompletionWindow : CompletionWindowBase /// /// Creates a new code completion window. /// - public CompletionWindow(TextArea textArea) : base(textArea) + public CompletionWindow(TextArea textArea, ICompletionProvider completionProvider) : base(textArea) { + this.completionProvider = completionProvider; + CompletionList = new CompletionList(); + + // For using our own UpdateQuery + CompletionList.IsFiltering = false; + // keep height automatic CloseAutomatically = true; MaxHeight = 225; @@ -249,10 +258,22 @@ public class CompletionWindow : CompletionWindowBase { var newText = document.GetText(StartOffset, offset - StartOffset); Debug.WriteLine("CaretPositionChanged newText: " + newText); - CompletionList.SelectItem(newText); - + // CompletionList.SelectItem(newText); + UpdateQuery(newText); + IsVisible = CompletionList.ListBox.ItemCount != 0; } } } + + /// + /// Update the completion window's current search term. + /// + public void UpdateQuery(string searchTerm) + { + var results = completionProvider.GetCompletions(searchTerm, 30, false); + CompletionList.CompletionData.Clear(); + CompletionList.CompletionData.AddRange(results); + CompletionList.SelectItem(searchTerm); + } } diff --git a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml index 423e2f2c..348d907a 100644 --- a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml +++ b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml @@ -68,7 +68,8 @@ Document="{Binding PromptDocument}" FontFamily="Cascadia Code,Consolas,Menlo,Monospace"> - + @@ -101,7 +102,8 @@ Document="{Binding NegativePromptDocument}" FontFamily="Cascadia Code,Consolas,Menlo,Monospace"> - + diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs index 11a35ca4..783f6110 100644 --- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs +++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs @@ -79,6 +79,7 @@ public static class DesignData services.AddLogging() .AddSingleton() .AddSingleton() + .AddSingleton() .AddSingleton() .AddSingleton(); diff --git a/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs b/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs new file mode 100644 index 00000000..1afa66e9 --- /dev/null +++ b/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs @@ -0,0 +1,70 @@ +using System.Collections.Generic; +using System.Data.Common; +using System.Globalization; +using System.IO; +using System.Threading.Tasks; +using StabilityMatrix.Avalonia.Models.TagCompletion; +using Sylvan.Data.Csv; +using Sylvan; +using Sylvan.Data; + +namespace StabilityMatrix.Avalonia.Helpers; + +public class TagCsvParser +{ + private readonly Stream stream; + + public TagCsvParser(Stream stream) + { + this.stream = stream; + } + + public async IAsyncEnumerable ParseAsync() + { + var pool = new StringPool(); + var options = new CsvDataReaderOptions + { + StringFactory = pool.GetString, + HasHeaders = false, + }; + + using var textReader = new StreamReader(stream); + await using var dataReader = await CsvDataReader.CreateAsync(textReader, options); + + while (await dataReader.ReadAsync()) + { + var entry = new TagCsvEntry + { + Name = dataReader.GetString(0), + Type = dataReader.GetInt32(1), + Count = dataReader.GetInt32(2), + Aliases = dataReader.GetString(3), + }; + yield return entry; + } + + /*var dataBinderOptions = new DataBinderOptions + { + BindingMode = DataBindingMode.Any + };*/ + /*var results = dataReader.GetRecordsAsync(dataBinderOptions); + return results;*/ + } + + public async Task> GetDictionaryAsync() + { + var dict = new Dictionary(); + + await foreach (var entry in ParseAsync()) + { + if (entry.Name is null || entry.Type is null) + { + continue; + } + + dict.Add(entry.Name, entry); + } + + return dict; + } +} diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs new file mode 100644 index 00000000..e3275a93 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using AutoComplete.Builders; +using AutoComplete.Clients.IndexSearchers; +using AutoComplete.DataStructure; +using AutoComplete.Domain; +using NLog; +using StabilityMatrix.Avalonia.Controls.CodeCompletion; +using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; + +namespace StabilityMatrix.Avalonia.Models.TagCompletion; + +public class CompletionProvider : ICompletionProvider +{ + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + private readonly Dictionary entries = new(); + + private InMemoryIndexSearcher? searcher; + + public bool IsLoaded => searcher is not null; + + public async Task LoadFromFile(FilePath path, bool recreate = false) + { + // Get Blake3 hash of file + var hash = await FileHash.GetBlake3Async(path); + + Logger.Trace("Loading tags from {Path} with Blake3 hash {Hash}", path, hash); + + // Check for AppData/StabilityMatrix/Temp/Tags//*.bin + var tempTagsDir = GlobalConfig.HomeDir.JoinDir("Temp", "Tags"); + tempTagsDir.Create(); + var hashDir = tempTagsDir.JoinDir(hash); + + var headerFile = hashDir.JoinFile("header.bin"); + var indexFile = hashDir.JoinFile("index.bin"); + var tailFile = hashDir.JoinFile("tail.bin"); + + entries.Clear(); + + // If directory or any file is missing, rebuild the index + if (recreate || !(hashDir.Exists && headerFile.Exists && indexFile.Exists && tailFile.Exists)) + { + Logger.Trace("Creating new index for {Path}", hashDir); + hashDir.Create(); + + await using var headerStream = headerFile.Info.OpenWrite(); + await using var indexStream = indexFile.Info.OpenWrite(); + await using var tailStream = tailFile.Info.OpenWrite(); + + var builder = new IndexBuilder(headerStream, indexStream, tailStream); + + // Parse csv + var csvStream = path.Info.OpenRead(); + var parser = new TagCsvParser(csvStream); + + await foreach (var entry in parser.ParseAsync()) + { + if (string.IsNullOrWhiteSpace(entry.Name)) + { + continue; + } + // Add to index + builder.Add(entry.Name); + // Add to local dictionary + entries.Add(entry.Name, entry); + } + + builder.Build(); + } + + searcher = new InMemoryIndexSearcher(headerFile, indexFile, tailFile); + searcher.Init(); + } + + /// + public IEnumerable GetCompletions(string searchTerm, int itemsCount, bool suggest) + { + if (searcher is null) + { + throw new InvalidOperationException("Index is not loaded"); + } + + var searchOptions = new SearchOptions + { + Term = searchTerm, + MaxItemCount = itemsCount, + SuggestWhenFoundStartsWith = suggest + }; + + var result = searcher.Search(searchOptions); + + // No results + if (result.ResultType == TrieNodeSearchResultType.NotFound) + { + Logger.Trace("No results for {Term}", searchTerm); + return Array.Empty(); + } + + Logger.Trace("Got {Count} results for {Term}", result.Items.Length, searchTerm); + + // Get entry for each result + var completions = new List(); + foreach (var item in result.Items) + { + if (entries.TryGetValue(item, out var entry)) + { + var entryType = TagTypeExtensions.FromE621(entry.Type.GetValueOrDefault(-1)); + completions.Add(new TagCompletionData(entry.Name!, entryType)); + } + } + + return completions; + } +} diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs new file mode 100644 index 00000000..4ad77f5e --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs @@ -0,0 +1,28 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using StabilityMatrix.Avalonia.Controls.CodeCompletion; +using StabilityMatrix.Core.Models.FileInterfaces; + +namespace StabilityMatrix.Avalonia.Models.TagCompletion; + +public interface ICompletionProvider +{ + /// + /// Whether the completion provider is loaded. + /// + bool IsLoaded { get; } + + /// + /// Load the completion provider from a file. + /// + Task LoadFromFile(FilePath path, bool recreate = false); + + /// + /// Returns a list of completion items for the given text. + /// + public IEnumerable GetCompletions( + string searchTerm, + int itemsCount, + bool suggest + ); +} diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj index 932faa2e..61adfd90 100644 --- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj +++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj @@ -15,6 +15,7 @@ + @@ -48,6 +49,9 @@ + + + diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs index 95f1c13c..2e368413 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs @@ -1,21 +1,25 @@ using System.Text.Json.Nodes; using AvaloniaEdit.Document; -using CommunityToolkit.Mvvm.ComponentModel; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(PromptCard))] -public partial class PromptCardViewModel : LoadableViewModelBase +public class PromptCardViewModel : LoadableViewModelBase { + public ICompletionProvider CompletionProvider { get; } + public TextDocument PromptDocument { get; } = new(); public TextDocument NegativePromptDocument { get; } = new(); - [ObservableProperty] private int editorFontSize = 14; - - [ObservableProperty] private string editorFontFamily = "Consolas"; + /// + public PromptCardViewModel(ICompletionProvider completionProvider) + { + CompletionProvider = completionProvider; + } /// public override JsonObject SaveStateToJsonObject() diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs index 2e2e485f..c888dbd4 100644 --- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs @@ -24,6 +24,7 @@ using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.Views; @@ -50,6 +51,7 @@ public partial class SettingsViewModel : PageViewModelBase private readonly IPrerequisiteHelper prerequisiteHelper; private readonly IPyRunner pyRunner; private readonly ServiceManager dialogFactory; + private readonly ICompletionProvider completionProvider; public SharedState SharedState { get; } @@ -93,13 +95,15 @@ public partial class SettingsViewModel : PageViewModelBase IPrerequisiteHelper prerequisiteHelper, IPyRunner pyRunner, ServiceManager dialogFactory, - SharedState sharedState) + SharedState sharedState, + ICompletionProvider completionProvider) { this.notificationService = notificationService; this.settingsManager = settingsManager; this.prerequisiteHelper = prerequisiteHelper; this.pyRunner = pyRunner; this.dialogFactory = dialogFactory; + this.completionProvider = completionProvider; SharedState = sharedState; @@ -436,6 +440,19 @@ public partial class SettingsViewModel : PageViewModelBase await dialog.ShowAsync(); } + + [RelayCommand] + private async Task DebugLoadCompletionCsv() + { + var provider = App.StorageProvider; + var files = await provider.OpenFilePickerAsync(new FilePickerOpenOptions()); + + if (files.Count == 0) return; + + await completionProvider.LoadFromFile(files[0].TryGetLocalPath()!, true); + + notificationService.Show("Loaded completion file", ""); + } #endregion #region Info Section diff --git a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml index f75923af..63a5545c 100644 --- a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml +++ b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml @@ -252,6 +252,14 @@ + + +