Browse Source

Add CompletionProvider, debug option for loading

pull/165/head
Ionite 1 year ago
parent
commit
7fc4b7b4eb
No known key found for this signature in database
  1. 3
      StabilityMatrix.Avalonia/App.axaml.cs
  2. 24
      StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs
  3. 7
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs
  4. 29
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs
  5. 6
      StabilityMatrix.Avalonia/Controls/PromptCard.axaml
  6. 1
      StabilityMatrix.Avalonia/DesignData/DesignData.cs
  7. 70
      StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs
  8. 119
      StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs
  9. 28
      StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs
  10. 4
      StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
  11. 14
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  12. 19
      StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
  13. 8
      StabilityMatrix.Avalonia/Views/SettingsPage.axaml

3
StabilityMatrix.Avalonia/App.axaml.cs

@ -32,9 +32,11 @@ using Polly.Timeout;
using Refit;
using Sentry;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.DesignData;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
@ -365,6 +367,7 @@ public sealed class App : Application
services.AddSingleton<IPyRunner, PyRunner>();
services.AddSingleton<IUpdateHelper, UpdateHelper>();
services.AddSingleton<IInferenceClientManager, InferenceClientManager>();
services.AddSingleton<ICompletionProvider, CompletionProvider>();
// Rich presence
services.AddSingleton<IDiscordRichPresenceService, DiscordRichPresenceService>();

24
StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs

@ -9,6 +9,7 @@ using AvaloniaEdit.Editing;
using NLog;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using CompletionWindow = StabilityMatrix.Avalonia.Controls.CodeCompletion.CompletionWindow;
namespace StabilityMatrix.Avalonia.Behaviors;
@ -22,13 +23,13 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
private CompletionWindow? completionWindow;
// ReSharper disable once MemberCanBePrivate.Global
public static readonly StyledProperty<string> TextProperty =
AvaloniaProperty.Register<TextEditorCompletionBehavior, string>(nameof(Text));
public static readonly StyledProperty<ICompletionProvider> CompletionProviderProperty =
AvaloniaProperty.Register<TextEditorCompletionBehavior, ICompletionProvider>(nameof(CompletionProvider));
public string Text
public ICompletionProvider CompletionProvider
{
get => GetValue(TextProperty);
set => SetValue(TextProperty, value);
get => GetValue(CompletionProviderProperty);
set => SetValue(CompletionProviderProperty, value);
}
protected override void OnAttached()
@ -55,7 +56,7 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
private CompletionWindow CreateCompletionWindow(TextArea textArea)
{
var window = new CompletionWindow(textArea)
var window = new CompletionWindow(textArea, CompletionProvider)
{
WindowManagerAddShadowHint = false,
CloseWhenCaretAtBeginning = true,
@ -66,13 +67,6 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
IsFiltering = true
}
};
var completionList = window.CompletionList;
completionList.CompletionData.Add(new CompletionData("item1"));
completionList.CompletionData.Add(new CompletionData("item2"));
completionList.CompletionData.Add(new CompletionData("item3"));
return window;
}
@ -98,8 +92,8 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
completionWindow = CreateCompletionWindow(textEditor.TextArea);
completionWindow.StartOffset = tokenSegment.Offset;
completionWindow.EndOffset = tokenSegment.EndOffset;
completionWindow.CompletionList.SelectItem(token);
completionWindow.UpdateQuery(token);
completionWindow.Closed += delegate
{

7
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs

@ -284,14 +284,15 @@ public class CompletionList : TemplatedControl
private void SelectItemFiltering(string query)
{
// if the user just typed one more character, don't filter all data but just filter what we are already displaying
var listToFilter =
/*var listToFilter =
_currentList != null
&& !string.IsNullOrEmpty(_currentText)
&& !string.IsNullOrEmpty(query)
&& query.StartsWith(_currentText, StringComparison.Ordinal)
? _currentList
: _completionData;
: _completionData;*/
var listToFilter = _completionData;
var matchingItems =
from item in listToFilter
let quality = GetMatchQuality(item.Text, query)

29
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs

@ -25,6 +25,8 @@ using Avalonia.Input;
using Avalonia.Media;
using AvaloniaEdit.Document;
using AvaloniaEdit.Editing;
using AvaloniaEdit.Utils;
using StabilityMatrix.Avalonia.Models.TagCompletion;
namespace StabilityMatrix.Avalonia.Controls.CodeCompletion;
@ -35,7 +37,8 @@ public class CompletionWindow : CompletionWindowBase
{
private PopupWithCustomPosition _toolTip;
private ContentControl _toolTipContent;
private ICompletionProvider completionProvider;
/// <summary>
/// Gets the completion list used in this completion window.
@ -45,9 +48,15 @@ public class CompletionWindow : CompletionWindowBase
/// <summary>
/// Creates a new code completion window.
/// </summary>
public CompletionWindow(TextArea textArea) : base(textArea)
public CompletionWindow(TextArea textArea, ICompletionProvider completionProvider) : base(textArea)
{
this.completionProvider = completionProvider;
CompletionList = new CompletionList();
// For using our own UpdateQuery
CompletionList.IsFiltering = false;
// keep height automatic
CloseAutomatically = true;
MaxHeight = 225;
@ -249,10 +258,22 @@ public class CompletionWindow : CompletionWindowBase
{
var newText = document.GetText(StartOffset, offset - StartOffset);
Debug.WriteLine("CaretPositionChanged newText: " + newText);
CompletionList.SelectItem(newText);
// CompletionList.SelectItem(newText);
UpdateQuery(newText);
IsVisible = CompletionList.ListBox.ItemCount != 0;
}
}
}
/// <summary>
/// Update the completion window's current search term.
/// </summary>
public void UpdateQuery(string searchTerm)
{
var results = completionProvider.GetCompletions(searchTerm, 30, false);
CompletionList.CompletionData.Clear();
CompletionList.CompletionData.AddRange(results);
CompletionList.SelectItem(searchTerm);
}
}

6
StabilityMatrix.Avalonia/Controls/PromptCard.axaml

@ -68,7 +68,8 @@
Document="{Binding PromptDocument}"
FontFamily="Cascadia Code,Consolas,Menlo,Monospace">
<i:Interaction.Behaviors>
<behaviors:TextEditorCompletionBehavior/>
<behaviors:TextEditorCompletionBehavior
CompletionProvider="{Binding CompletionProvider}"/>
</i:Interaction.Behaviors>
</avaloniaEdit:TextEditor>
@ -101,7 +102,8 @@
Document="{Binding NegativePromptDocument}"
FontFamily="Cascadia Code,Consolas,Menlo,Monospace">
<i:Interaction.Behaviors>
<behaviors:TextEditorCompletionBehavior/>
<behaviors:TextEditorCompletionBehavior
CompletionProvider="{Binding CompletionProvider}"/>
</i:Interaction.Behaviors>
</avaloniaEdit:TextEditor>

1
StabilityMatrix.Avalonia/DesignData/DesignData.cs

@ -79,6 +79,7 @@ public static class DesignData
services.AddLogging()
.AddSingleton<IPackageFactory, PackageFactory>()
.AddSingleton<IUpdateHelper, UpdateHelper>()
.AddSingleton<ICompletionProvider, CompletionProvider>()
.AddSingleton<ModelFinder>()
.AddSingleton<SharedState>();

70
StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs

@ -0,0 +1,70 @@
using System.Collections.Generic;
using System.Data.Common;
using System.Globalization;
using System.IO;
using System.Threading.Tasks;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using Sylvan.Data.Csv;
using Sylvan;
using Sylvan.Data;
namespace StabilityMatrix.Avalonia.Helpers;
public class TagCsvParser
{
private readonly Stream stream;
public TagCsvParser(Stream stream)
{
this.stream = stream;
}
public async IAsyncEnumerable<TagCsvEntry> ParseAsync()
{
var pool = new StringPool();
var options = new CsvDataReaderOptions
{
StringFactory = pool.GetString,
HasHeaders = false,
};
using var textReader = new StreamReader(stream);
await using var dataReader = await CsvDataReader.CreateAsync(textReader, options);
while (await dataReader.ReadAsync())
{
var entry = new TagCsvEntry
{
Name = dataReader.GetString(0),
Type = dataReader.GetInt32(1),
Count = dataReader.GetInt32(2),
Aliases = dataReader.GetString(3),
};
yield return entry;
}
/*var dataBinderOptions = new DataBinderOptions
{
BindingMode = DataBindingMode.Any
};*/
/*var results = dataReader.GetRecordsAsync<TagCsvEntry>(dataBinderOptions);
return results;*/
}
public async Task<Dictionary<string, TagCsvEntry>> GetDictionaryAsync()
{
var dict = new Dictionary<string, TagCsvEntry>();
await foreach (var entry in ParseAsync())
{
if (entry.Name is null || entry.Type is null)
{
continue;
}
dict.Add(entry.Name, entry);
}
return dict;
}
}

119
StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs

@ -0,0 +1,119 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AutoComplete.Builders;
using AutoComplete.Clients.IndexSearchers;
using AutoComplete.DataStructure;
using AutoComplete.Domain;
using NLog;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
public class CompletionProvider : ICompletionProvider
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly Dictionary<string, TagCsvEntry> entries = new();
private InMemoryIndexSearcher? searcher;
public bool IsLoaded => searcher is not null;
public async Task LoadFromFile(FilePath path, bool recreate = false)
{
// Get Blake3 hash of file
var hash = await FileHash.GetBlake3Async(path);
Logger.Trace("Loading tags from {Path} with Blake3 hash {Hash}", path, hash);
// Check for AppData/StabilityMatrix/Temp/Tags/<hash>/*.bin
var tempTagsDir = GlobalConfig.HomeDir.JoinDir("Temp", "Tags");
tempTagsDir.Create();
var hashDir = tempTagsDir.JoinDir(hash);
var headerFile = hashDir.JoinFile("header.bin");
var indexFile = hashDir.JoinFile("index.bin");
var tailFile = hashDir.JoinFile("tail.bin");
entries.Clear();
// If directory or any file is missing, rebuild the index
if (recreate || !(hashDir.Exists && headerFile.Exists && indexFile.Exists && tailFile.Exists))
{
Logger.Trace("Creating new index for {Path}", hashDir);
hashDir.Create();
await using var headerStream = headerFile.Info.OpenWrite();
await using var indexStream = indexFile.Info.OpenWrite();
await using var tailStream = tailFile.Info.OpenWrite();
var builder = new IndexBuilder(headerStream, indexStream, tailStream);
// Parse csv
var csvStream = path.Info.OpenRead();
var parser = new TagCsvParser(csvStream);
await foreach (var entry in parser.ParseAsync())
{
if (string.IsNullOrWhiteSpace(entry.Name))
{
continue;
}
// Add to index
builder.Add(entry.Name);
// Add to local dictionary
entries.Add(entry.Name, entry);
}
builder.Build();
}
searcher = new InMemoryIndexSearcher(headerFile, indexFile, tailFile);
searcher.Init();
}
/// <inheritdoc />
public IEnumerable<ICompletionData> GetCompletions(string searchTerm, int itemsCount, bool suggest)
{
if (searcher is null)
{
throw new InvalidOperationException("Index is not loaded");
}
var searchOptions = new SearchOptions
{
Term = searchTerm,
MaxItemCount = itemsCount,
SuggestWhenFoundStartsWith = suggest
};
var result = searcher.Search(searchOptions);
// No results
if (result.ResultType == TrieNodeSearchResultType.NotFound)
{
Logger.Trace("No results for {Term}", searchTerm);
return Array.Empty<ICompletionData>();
}
Logger.Trace("Got {Count} results for {Term}", result.Items.Length, searchTerm);
// Get entry for each result
var completions = new List<ICompletionData>();
foreach (var item in result.Items)
{
if (entries.TryGetValue(item, out var entry))
{
var entryType = TagTypeExtensions.FromE621(entry.Type.GetValueOrDefault(-1));
completions.Add(new TagCompletionData(entry.Name!, entryType));
}
}
return completions;
}
}

28
StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs

@ -0,0 +1,28 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
public interface ICompletionProvider
{
/// <summary>
/// Whether the completion provider is loaded.
/// </summary>
bool IsLoaded { get; }
/// <summary>
/// Load the completion provider from a file.
/// </summary>
Task LoadFromFile(FilePath path, bool recreate = false);
/// <summary>
/// Returns a list of completion items for the given text.
/// </summary>
public IEnumerable<ICompletionData> GetCompletions(
string searchTerm,
int itemsCount,
bool suggest
);
}

4
StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

@ -15,6 +15,7 @@
<ItemGroup>
<PackageReference Include="AsyncImageLoader.Avalonia" Version="3.0.0" />
<PackageReference Include="AutoComplete.Net" Version="1.2211.2014.42"/>
<PackageReference Include="Avalonia" Version="11.0.2" />
<PackageReference Include="Avalonia.AvaloniaEdit" Version="11.0.0" />
<PackageReference Include="Avalonia.Desktop" Version="11.0.2" />
@ -48,6 +49,9 @@
<PackageReference Include="RockLib.Reflection.Optimized" Version="2.0.0" />
<PackageReference Include="Sentry" Version="3.33.1" />
<PackageReference Include="Sentry.NLog" Version="3.33.1" />
<PackageReference Include="Sylvan.Common" Version="0.4.2" />
<PackageReference Include="Sylvan.Data" Version="0.2.12" />
<PackageReference Include="Sylvan.Data.Csv" Version="1.3.3" />
<PackageReference Include="TextMateSharp.Grammars" Version="1.0.55" />
</ItemGroup>

14
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,21 +1,25 @@
using System.Text.Json.Nodes;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))]
public partial class PromptCardViewModel : LoadableViewModelBase
public class PromptCardViewModel : LoadableViewModelBase
{
public ICompletionProvider CompletionProvider { get; }
public TextDocument PromptDocument { get; } = new();
public TextDocument NegativePromptDocument { get; } = new();
[ObservableProperty] private int editorFontSize = 14;
[ObservableProperty] private string editorFontFamily = "Consolas";
/// <inheritdoc />
public PromptCardViewModel(ICompletionProvider completionProvider)
{
CompletionProvider = completionProvider;
}
/// <inheritdoc />
public override JsonObject SaveStateToJsonObject()

19
StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs

@ -24,6 +24,7 @@ using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.Views;
@ -50,6 +51,7 @@ public partial class SettingsViewModel : PageViewModelBase
private readonly IPrerequisiteHelper prerequisiteHelper;
private readonly IPyRunner pyRunner;
private readonly ServiceManager<ViewModelBase> dialogFactory;
private readonly ICompletionProvider completionProvider;
public SharedState SharedState { get; }
@ -93,13 +95,15 @@ public partial class SettingsViewModel : PageViewModelBase
IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner,
ServiceManager<ViewModelBase> dialogFactory,
SharedState sharedState)
SharedState sharedState,
ICompletionProvider completionProvider)
{
this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.prerequisiteHelper = prerequisiteHelper;
this.pyRunner = pyRunner;
this.dialogFactory = dialogFactory;
this.completionProvider = completionProvider;
SharedState = sharedState;
@ -436,6 +440,19 @@ public partial class SettingsViewModel : PageViewModelBase
await dialog.ShowAsync();
}
[RelayCommand]
private async Task DebugLoadCompletionCsv()
{
var provider = App.StorageProvider;
var files = await provider.OpenFilePickerAsync(new FilePickerOpenOptions());
if (files.Count == 0) return;
await completionProvider.LoadFromFile(files[0].TryGetLocalPath()!, true);
notificationService.Show("Loaded completion file", "");
}
#endregion
#region Info Section

8
StabilityMatrix.Avalonia/Views/SettingsPage.axaml

@ -252,6 +252,14 @@
</ui:SettingsExpanderItem.Footer>
</ui:SettingsExpanderItem>
<ui:SettingsExpanderItem Content="Load Completion Source" IconSource="ImageCopy">
<ui:SettingsExpanderItem.Footer>
<Button
Command="{Binding DebugLoadCompletionCsvCommand}"
Content="Load CSV File"/>
</ui:SettingsExpanderItem.Footer>
</ui:SettingsExpanderItem>
</ui:SettingsExpander>
</Grid>

Loading…
Cancel
Save