|
|
|
@ -14,7 +14,9 @@ using Nito.AsyncEx;
|
|
|
|
|
using NLog; |
|
|
|
|
using StabilityMatrix.Avalonia.Controls.CodeCompletion; |
|
|
|
|
using StabilityMatrix.Avalonia.Helpers; |
|
|
|
|
using StabilityMatrix.Avalonia.Models.Inference.Tokens; |
|
|
|
|
using StabilityMatrix.Avalonia.Services; |
|
|
|
|
using StabilityMatrix.Core.Extensions; |
|
|
|
|
using StabilityMatrix.Core.Helper; |
|
|
|
|
using StabilityMatrix.Core.Models; |
|
|
|
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
|
|
@ -29,26 +31,33 @@ public class CompletionProvider : ICompletionProvider
|
|
|
|
|
|
|
|
|
|
private readonly ISettingsManager settingsManager; |
|
|
|
|
private readonly INotificationService notificationService; |
|
|
|
|
|
|
|
|
|
private readonly IModelIndexService modelIndexService; |
|
|
|
|
|
|
|
|
|
private readonly AsyncLock loadLock = new(); |
|
|
|
|
private readonly Dictionary<string, TagCsvEntry> entries = new(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private InMemoryIndexSearcher? searcher; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public bool IsLoaded => searcher is not null; |
|
|
|
|
|
|
|
|
|
public Func<string, string>? PrepareInsertionText |
|
|
|
|
=> settingsManager.Settings.IsCompletionRemoveUnderscoresEnabled |
|
|
|
|
? PrepareInsertionText_RemoveUnderscores : null; |
|
|
|
|
|
|
|
|
|
public CompletionProvider(ISettingsManager settingsManager, INotificationService notificationService) |
|
|
|
|
public Func<string, string>? PrepareInsertionText => |
|
|
|
|
settingsManager.Settings.IsCompletionRemoveUnderscoresEnabled |
|
|
|
|
? PrepareInsertionText_RemoveUnderscores |
|
|
|
|
: null; |
|
|
|
|
|
|
|
|
|
public CompletionProvider( |
|
|
|
|
ISettingsManager settingsManager, |
|
|
|
|
INotificationService notificationService, |
|
|
|
|
IModelIndexService modelIndexService |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
this.settingsManager = settingsManager; |
|
|
|
|
this.notificationService = notificationService; |
|
|
|
|
this.modelIndexService = modelIndexService; |
|
|
|
|
|
|
|
|
|
// Attach to load from set file on initial settings load |
|
|
|
|
settingsManager.Loaded += (_, _) => UpdateTagCompletionCsv(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Also load when TagCompletionCsv property changes |
|
|
|
|
settingsManager.SettingsPropertyChanged += (_, args) => |
|
|
|
|
{ |
|
|
|
@ -57,25 +66,26 @@ public class CompletionProvider : ICompletionProvider
|
|
|
|
|
UpdateTagCompletionCsv(); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If library already loaded, start a background load |
|
|
|
|
if (settingsManager.IsLibraryDirSet) |
|
|
|
|
{ |
|
|
|
|
UpdateTagCompletionCsv(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
void UpdateTagCompletionCsv() |
|
|
|
|
{ |
|
|
|
|
var csvPath = settingsManager.Settings.TagCompletionCsv; |
|
|
|
|
if (csvPath is null) return; |
|
|
|
|
if (csvPath is null) |
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
var fullPath = settingsManager.TagsDirectory.JoinFile(csvPath); |
|
|
|
|
BackgroundLoadFromFile(fullPath); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static string PrepareInsertionText_RemoveUnderscores(string text) |
|
|
|
|
{ |
|
|
|
|
return text.Replace("_", " "); |
|
|
|
@ -84,62 +94,70 @@ public class CompletionProvider : ICompletionProvider
|
|
|
|
|
/// <inheritdoc /> |
|
|
|
|
public void BackgroundLoadFromFile(FilePath path, bool recreate = false) |
|
|
|
|
{ |
|
|
|
|
LoadFromFile(path, recreate).SafeFireAndForget(onException: exception => |
|
|
|
|
{ |
|
|
|
|
const string title = "Failed to load tag completion file"; |
|
|
|
|
Debug.Fail(title); |
|
|
|
|
Logger.Warn(exception, title); |
|
|
|
|
notificationService.Show(title + $" {path.Name}", |
|
|
|
|
exception.Message, NotificationType.Error); |
|
|
|
|
}, true); |
|
|
|
|
LoadFromFile(path, recreate) |
|
|
|
|
.SafeFireAndForget( |
|
|
|
|
onException: exception => |
|
|
|
|
{ |
|
|
|
|
const string title = "Failed to load tag completion file"; |
|
|
|
|
Debug.Fail(title); |
|
|
|
|
Logger.Warn(exception, title); |
|
|
|
|
notificationService.Show( |
|
|
|
|
title + $" {path.Name}", |
|
|
|
|
exception.Message, |
|
|
|
|
NotificationType.Error |
|
|
|
|
); |
|
|
|
|
}, |
|
|
|
|
true |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <inheritdoc /> |
|
|
|
|
public async Task LoadFromFile(FilePath path, bool recreate = false) |
|
|
|
|
{ |
|
|
|
|
using var _ = await loadLock.LockAsync(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Get Blake3 hash of file |
|
|
|
|
var hash = await FileHash.GetBlake3Async(path); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Logger.Trace("Loading tags from {Path} with Blake3 hash {Hash}", path, hash); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Check for AppData/StabilityMatrix/Temp/Tags/<hash>/*.bin |
|
|
|
|
var tempTagsDir = GlobalConfig.HomeDir.JoinDir("Temp", "Tags"); |
|
|
|
|
var hashDir = tempTagsDir.JoinDir(hash); |
|
|
|
|
hashDir.Create(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var headerFile = hashDir.JoinFile("header.bin"); |
|
|
|
|
var indexFile = hashDir.JoinFile("index.bin"); |
|
|
|
|
|
|
|
|
|
entries.Clear(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var timer = Stopwatch.StartNew(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If directory or any file is missing, rebuild the index |
|
|
|
|
if (recreate || !(hashDir.Exists && headerFile.Exists && indexFile.Exists)) |
|
|
|
|
{ |
|
|
|
|
Logger.Debug("Creating new index for {Path}", hashDir); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await using var headerStream = headerFile.Info.OpenWrite(); |
|
|
|
|
await using var indexStream = indexFile.Info.OpenWrite(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var builder = new IndexBuilder(headerStream, indexStream); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Parse csv |
|
|
|
|
await using var csvStream = path.Info.OpenRead(); |
|
|
|
|
var parser = new TagCsvParser(csvStream); |
|
|
|
|
|
|
|
|
|
await foreach (var entry in parser.ParseAsync()) |
|
|
|
|
{ |
|
|
|
|
if (string.IsNullOrWhiteSpace(entry.Name)) continue; |
|
|
|
|
|
|
|
|
|
if (string.IsNullOrWhiteSpace(entry.Name)) |
|
|
|
|
continue; |
|
|
|
|
|
|
|
|
|
// Add to index |
|
|
|
|
builder.Add(entry.Name); |
|
|
|
|
// Add to local dictionary |
|
|
|
|
entries.Add(entry.Name, entry); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await Task.Run(builder.Build); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
@ -149,82 +167,104 @@ public class CompletionProvider : ICompletionProvider
|
|
|
|
|
|
|
|
|
|
await using var csvStream = path.Info.OpenRead(); |
|
|
|
|
var parser = new TagCsvParser(csvStream); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await foreach (var entry in parser.ParseAsync()) |
|
|
|
|
{ |
|
|
|
|
if (string.IsNullOrWhiteSpace(entry.Name)) continue; |
|
|
|
|
|
|
|
|
|
if (string.IsNullOrWhiteSpace(entry.Name)) |
|
|
|
|
continue; |
|
|
|
|
|
|
|
|
|
// Add to local dictionary |
|
|
|
|
entries.Add(entry.Name, entry); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
searcher = new InMemoryIndexSearcher(headerFile, indexFile); |
|
|
|
|
searcher.Init(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var elapsed = timer.Elapsed; |
|
|
|
|
|
|
|
|
|
Logger.Info("Loaded {Count} tags for {Path} in {Time:F2}s", entries.Count, path.Name, elapsed.TotalSeconds); |
|
|
|
|
|
|
|
|
|
Logger.Info( |
|
|
|
|
"Loaded {Count} tags for {Path} in {Time:F2}s", |
|
|
|
|
entries.Count, |
|
|
|
|
path.Name, |
|
|
|
|
elapsed.TotalSeconds |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// <inheritdoc /> |
|
|
|
|
public IEnumerable<ICompletionData> GetCompletions(string searchTerm, int itemsCount, bool suggest) |
|
|
|
|
public IEnumerable<ICompletionData> GetCompletions( |
|
|
|
|
TextCompletionRequest completionRequest, |
|
|
|
|
int itemsCount, |
|
|
|
|
bool suggest |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
return GetCompletionsImpl_Index(searchTerm, itemsCount, suggest); |
|
|
|
|
} |
|
|
|
|
if (completionRequest.Type == CompletionType.Tag) |
|
|
|
|
{ |
|
|
|
|
return GetCompletionTags(completionRequest.Text, itemsCount, suggest); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private IEnumerable<ICompletionData> GetCompletionsImpl_Fuzzy(string searchTerm, int itemsCount, bool suggest) |
|
|
|
|
{ |
|
|
|
|
var extracted = FuzzySharp.Process |
|
|
|
|
.ExtractTop(searchTerm, entries.Keys); |
|
|
|
|
|
|
|
|
|
var results = extracted |
|
|
|
|
.Where(r => r.Score > 40) |
|
|
|
|
.Select(r => r.Value) |
|
|
|
|
.ToImmutableArray(); |
|
|
|
|
|
|
|
|
|
// No results |
|
|
|
|
if (results.IsEmpty) |
|
|
|
|
if (completionRequest.Type == CompletionType.ExtraNetwork) |
|
|
|
|
{ |
|
|
|
|
Logger.Trace("No results for {Term}", searchTerm); |
|
|
|
|
return Array.Empty<ICompletionData>(); |
|
|
|
|
return GetCompletionNetworks( |
|
|
|
|
completionRequest.ExtraNetworkTypes, |
|
|
|
|
completionRequest.Text, |
|
|
|
|
itemsCount |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Logger.Trace("Got {Count} results for {Term}", results.Length, searchTerm); |
|
|
|
|
|
|
|
|
|
// Get entry for each result |
|
|
|
|
|
|
|
|
|
throw new InvalidOperationException(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private IEnumerable<ICompletionData> GetCompletionNetworks( |
|
|
|
|
PromptExtraNetworkType networkType, |
|
|
|
|
string searchTerm, |
|
|
|
|
int itemsCount |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
var folderTypes = Enum.GetValues(typeof(PromptExtraNetworkType)) |
|
|
|
|
.Cast<PromptExtraNetworkType>() |
|
|
|
|
.Where(f => networkType.HasFlag(f)) |
|
|
|
|
.Select(network => network.ConvertTo<SharedFolderType>()); |
|
|
|
|
|
|
|
|
|
var completions = new List<ICompletionData>(); |
|
|
|
|
foreach (var item in results) |
|
|
|
|
|
|
|
|
|
foreach (var folderType in folderTypes) |
|
|
|
|
{ |
|
|
|
|
if (entries.TryGetValue(item, out var entry)) |
|
|
|
|
// Get from index service |
|
|
|
|
if (modelIndexService.ModelIndex.TryGetValue(folderType, out var localModels)) |
|
|
|
|
{ |
|
|
|
|
var entryType = TagTypeExtensions.FromE621(entry.Type.GetValueOrDefault(-1)); |
|
|
|
|
completions.Add(new TagCompletionData(entry.Name!, entryType) |
|
|
|
|
{ |
|
|
|
|
Priority = entry.Count ?? 0 |
|
|
|
|
}); |
|
|
|
|
var results = |
|
|
|
|
from model in localModels |
|
|
|
|
where model.FileName.StartsWith(searchTerm, StringComparison.OrdinalIgnoreCase) |
|
|
|
|
select ModelCompletionData.FromLocalModel(model, networkType); |
|
|
|
|
|
|
|
|
|
completions.AddRange(results.Take(itemsCount)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return completions; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private IEnumerable<ICompletionData> GetCompletionsImpl_Index(string searchTerm, int itemsCount, bool suggest) |
|
|
|
|
|
|
|
|
|
private IEnumerable<ICompletionData> GetCompletionTags( |
|
|
|
|
string searchTerm, |
|
|
|
|
int itemsCount, |
|
|
|
|
bool suggest |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
if (searcher is null) |
|
|
|
|
{ |
|
|
|
|
throw new InvalidOperationException("Index is not loaded"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var timer = Stopwatch.StartNew(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var searchOptions = new SearchOptions |
|
|
|
|
{ |
|
|
|
|
Term = searchTerm, |
|
|
|
|
MaxItemCount = itemsCount, |
|
|
|
|
SuggestWhenFoundStartsWith = suggest |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var result = searcher.Search(searchOptions); |
|
|
|
|
|
|
|
|
|
// No results |
|
|
|
@ -233,16 +273,16 @@ public class CompletionProvider : ICompletionProvider
|
|
|
|
|
Logger.Trace("No results for {Term}", searchTerm); |
|
|
|
|
return Array.Empty<ICompletionData>(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Is null for some reason? |
|
|
|
|
if (result.Items is null) |
|
|
|
|
{ |
|
|
|
|
Logger.Warn("Got null results for {Term}", searchTerm); |
|
|
|
|
return Array.Empty<ICompletionData>(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Logger.Trace("Got {Count} results for {Term}", result.Items.Length, searchTerm); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Get entry for each result |
|
|
|
|
var completions = new List<ICompletionData>(); |
|
|
|
|
foreach (var item in result.Items) |
|
|
|
@ -253,10 +293,14 @@ public class CompletionProvider : ICompletionProvider
|
|
|
|
|
completions.Add(new TagCompletionData(entry.Name!, entryType)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
timer.Stop(); |
|
|
|
|
Logger.Trace("Completions for {Term} took {Time:F2}ms", searchTerm, timer.Elapsed.TotalMilliseconds); |
|
|
|
|
|
|
|
|
|
Logger.Trace( |
|
|
|
|
"Completions for {Term} took {Time:F2}ms", |
|
|
|
|
searchTerm, |
|
|
|
|
timer.Elapsed.TotalMilliseconds |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
return completions; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|