diff --git a/StabilityMatrix.Avalonia/Assets.cs b/StabilityMatrix.Avalonia/Assets.cs index cf8a0931..db08322e 100644 --- a/StabilityMatrix.Avalonia/Assets.cs +++ b/StabilityMatrix.Avalonia/Assets.cs @@ -32,7 +32,7 @@ internal static class Assets public static AvaloniaResource ThemeMatrixDarkJson => new("avares://StabilityMatrix.Avalonia/Assets/ThemeMatrixDark.json"); - private const UnixFileMode unix755 = + private const UnixFileMode Unix755 = UnixFileMode.UserRead | UnixFileMode.UserWrite | UnixFileMode.UserExecute @@ -54,14 +54,14 @@ internal static class Assets PlatformKind.Linux | PlatformKind.X64, new AvaloniaResource( "avares://StabilityMatrix.Avalonia/Assets/linux-x64/7zzs", - unix755 + Unix755 ) ), ( PlatformKind.MacOS | PlatformKind.Arm, new AvaloniaResource( "avares://StabilityMatrix.Avalonia/Assets/macos-arm64/7zz", - unix755 + Unix755 ) ) ); @@ -136,6 +136,19 @@ internal static class Assets ) ); + public static IReadOnlyList DefaultCompletionTags { get; } = + new[] + { + new RemoteResource( + new Uri("https://cdn.lykos.ai/tags/danbooru.csv"), + "b84a879f1d9c47bf4758d66542598faa565b1571122ae12e7b145da8e7a4c1c6" + ), + new RemoteResource( + new Uri("https://cdn.lykos.ai/tags/e621.csv"), + "ef7ea148ad865ad936d0c1ee57f0f83de723b43056c70b07fd67dbdbb89cae35" + ) + }; + public static Uri DiscordServerUrl { get; } = new("https://discord.com/invite/TUrgfECxHz"); public static Uri PatreonUrl { get; } = new("https://patreon.com/StabilityMatrix"); diff --git a/StabilityMatrix.Avalonia/DesignData/MockCompletionProvider.cs b/StabilityMatrix.Avalonia/DesignData/MockCompletionProvider.cs index d6702f2d..1b22b1eb 100644 --- a/StabilityMatrix.Avalonia/DesignData/MockCompletionProvider.cs +++ b/StabilityMatrix.Avalonia/DesignData/MockCompletionProvider.cs @@ -15,6 +15,12 @@ public class MockCompletionProvider : ICompletionProvider /// public Func? PrepareInsertionText { get; } = data => data.Text; + /// + public Task Setup() + { + return Task.CompletedTask; + } + /// public Task LoadFromFile(FilePath path, bool recreate = false) { diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs index cc9889da..a2bd34dc 100644 --- a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs @@ -36,6 +36,7 @@ public partial class CompletionProvider : ICompletionProvider private readonly ISettingsManager settingsManager; private readonly INotificationService notificationService; private readonly IModelIndexService modelIndexService; + private readonly IDownloadService downloadService; private readonly AsyncLock loadLock = new(); private readonly Dictionary entries = new(); @@ -61,12 +62,14 @@ public partial class CompletionProvider : ICompletionProvider public CompletionProvider( ISettingsManager settingsManager, INotificationService notificationService, - IModelIndexService modelIndexService + IModelIndexService modelIndexService, + IDownloadService downloadService ) { this.settingsManager = settingsManager; this.notificationService = notificationService; this.modelIndexService = modelIndexService; + this.downloadService = downloadService; PrepareInsertionText = PrepareInsertionText_Process; @@ -148,6 +151,48 @@ public partial class CompletionProvider : ICompletionProvider ); } + /// + public async Task Setup() + { + var tagsDir = settingsManager.TagsDirectory; + tagsDir.Create(); + + // If tagsDir is empty and no selected, download defaults + if ( + !tagsDir.Info.EnumerateFiles().Any() + && settingsManager.Settings.TagCompletionCsv is null + ) + { + foreach (var remoteCsv in Assets.DefaultCompletionTags) + { + var fileName = remoteCsv.Url.Segments.Last(); + Logger.Info( + "Downloading default tag source {Name} [{Hash}]", + fileName, + remoteCsv.HashSha256[..7] + ); + await downloadService.DownloadToFileAsync( + remoteCsv.Url.ToString(), + tagsDir.JoinFile(fileName) + ); + } + + var defaultFile = tagsDir.JoinFile("danbooru.csv"); + if (!defaultFile.Exists) + { + Logger.Warn("Failed to download default tag source"); + return; + } + + // Set default file as selected + settingsManager.Settings.TagCompletionCsv = defaultFile.Name; + Logger.Debug("Tag completion source set to {Name}", defaultFile.Name); + + // Load default file + BackgroundLoadFromFile(defaultFile); + } + } + /// public async Task LoadFromFile(FilePath path, bool recreate = false) { diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs index 3e091c94..2e7ed85a 100644 --- a/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs @@ -18,6 +18,11 @@ public interface ICompletionProvider /// Func? PrepareInsertionText => null; + /// + /// Downloads default tags and selects one if required. + /// + Task Setup(); + /// /// Load the completion provider from a file. /// diff --git a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs index 439e815c..3e71a877 100644 --- a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs @@ -11,6 +11,7 @@ using DynamicData.Binding; using Microsoft.Extensions.Logging; using SkiaSharp; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Inference; @@ -32,6 +33,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient private readonly IApiFactory apiFactory; private readonly IModelIndexService modelIndexService; private readonly ISettingsManager settingsManager; + private readonly ICompletionProvider completionProvider; [ObservableProperty] [NotifyPropertyChangedFor(nameof(IsConnected), nameof(CanUserConnect))] @@ -86,13 +88,15 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient ILogger logger, IApiFactory apiFactory, IModelIndexService modelIndexService, - ISettingsManager settingsManager + ISettingsManager settingsManager, + ICompletionProvider completionProvider ) { this.logger = logger; this.apiFactory = apiFactory; this.modelIndexService = modelIndexService; this.settingsManager = settingsManager; + this.completionProvider = completionProvider; modelsSource .Connect() @@ -355,6 +359,14 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient throw new ArgumentException("Base package is not ComfyUI", nameof(packagePair)); } + // Setup completion provider + completionProvider + .Setup() + .SafeFireAndForget(ex => + { + logger.LogError(ex, "Error setting up completion provider"); + }); + // Setup image folder links await comfyPackage.SetupInferenceOutputFolderLinks( packagePair.InstalledPackage.FullPath diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs index 6053e1e0..27ec36c3 100644 --- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs @@ -95,7 +95,7 @@ public partial class SettingsViewModel : PageViewModelBase // Inference UI section [ObservableProperty] - private bool isPromptCompletionEnabled; + private bool isPromptCompletionEnabled = true; [ObservableProperty] private IReadOnlyList availableTagCompletionCsvs = Array.Empty(); @@ -104,7 +104,7 @@ public partial class SettingsViewModel : PageViewModelBase private string? selectedTagCompletionCsv; [ObservableProperty] - private bool isCompletionRemoveUnderscoresEnabled; + private bool isCompletionRemoveUnderscoresEnabled = true; [ObservableProperty] private bool isImageViewerPixelGridEnabled = true; diff --git a/StabilityMatrix.Core/Models/Settings/Settings.cs b/StabilityMatrix.Core/Models/Settings/Settings.cs index 38f2d1b6..cb01f14e 100644 --- a/StabilityMatrix.Core/Models/Settings/Settings.cs +++ b/StabilityMatrix.Core/Models/Settings/Settings.cs @@ -58,7 +58,7 @@ public class Settings /// /// Whether prompt auto completion is enabled /// - public bool IsPromptCompletionEnabled { get; set; } + public bool IsPromptCompletionEnabled { get; set; } = true; /// /// Relative path to the tag completion CSV file from 'LibraryDir/Tags' @@ -68,7 +68,7 @@ public class Settings /// /// Whether to remove underscores from completions /// - public bool IsCompletionRemoveUnderscoresEnabled { get; set; } + public bool IsCompletionRemoveUnderscoresEnabled { get; set; } = true; /// /// Whether the Inference Image Viewer shows pixel grids at high zoom levels