From d29067a8d989b539e74d9fb223cc5e0bd659bab7 Mon Sep 17 00:00:00 2001 From: JT Date: Sun, 20 Aug 2023 23:20:28 -0700 Subject: [PATCH] WIP new page for connected checkpoint updates/management --- StabilityMatrix.Avalonia/App.axaml.cs | 3 + .../DesignData/DesignData.cs | 32 +++ .../CheckpointManager/CheckpointFile.cs | 70 +++++++ .../ViewModels/CheckpointsPageViewModel.cs | 4 +- .../ViewModels/NewCheckpointsPageViewModel.cs | 187 ++++++++++++++++++ .../Views/CheckpointBrowserPage.axaml | 2 +- .../Views/NewCheckpointsPage.axaml | 105 ++++++++++ .../Views/NewCheckpointsPage.axaml.cs | 11 ++ .../Models/Api/CivitModelsRequest.cs | 3 + 9 files changed, 414 insertions(+), 3 deletions(-) create mode 100644 StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs create mode 100644 StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml create mode 100644 StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml.cs diff --git a/StabilityMatrix.Avalonia/App.axaml.cs b/StabilityMatrix.Avalonia/App.axaml.cs index 487b73a0..374c1407 100644 --- a/StabilityMatrix.Avalonia/App.axaml.cs +++ b/StabilityMatrix.Avalonia/App.axaml.cs @@ -204,6 +204,7 @@ public sealed class App : Application .AddSingleton() .AddSingleton() .AddSingleton() + .AddSingleton() .AddSingleton() .AddSingleton(); @@ -217,6 +218,7 @@ public sealed class App : Application provider.GetRequiredService(), provider.GetRequiredService(), provider.GetRequiredService(), + provider.GetRequiredService(), provider.GetRequiredService(), }, FooterPages = @@ -285,6 +287,7 @@ public sealed class App : Application services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); // Dialogs services.AddTransient(); diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs index b0fc04ba..a68a7fa4 100644 --- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs +++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs @@ -223,6 +223,35 @@ public static class DesignData }) }; + NewCheckpointsPageViewModel.AllCheckpoints = new ObservableCollection + { + new() + { + FilePath = "~/Models/StableDiffusion/electricity-light.safetensors", + Title = "Auroral Background", + PreviewImagePath = "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/" + + "78fd2a0a-42b6-42b0-9815-81cb11bb3d05/00009-2423234823.jpeg", + ConnectedModel = new ConnectedModelInfo + { + VersionName = "Lightning Auroral", + BaseModel = "SD 1.5", + ModelName = "Auroral Background", + ModelType = CivitModelType.Model, + FileMetadata = new CivitFileMetadata + { + Format = CivitModelFormat.SafeTensor, + Fp = CivitModelFpType.fp16, + Size = CivitModelSize.pruned, + } + } + }, + new() + { + FilePath = "~/Models/Lora/model.safetensors", + Title = "Some model" + } + }; + ProgressManagerViewModel.ProgressItems = new ObservableCollection { new(new ProgressItem(Guid.NewGuid(), "Test File.exe", @@ -273,6 +302,9 @@ public static class DesignData public static CheckpointsPageViewModel CheckpointsPageViewModel => Services.GetRequiredService(); + public static NewCheckpointsPageViewModel NewCheckpointsPageViewModel => + Services.GetRequiredService(); + public static SettingsViewModel SettingsViewModel => Services.GetRequiredService(); diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs index a8f83521..8ecfff45 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs @@ -14,6 +14,7 @@ using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api; using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Processes; @@ -46,6 +47,7 @@ public partial class CheckpointFile : ViewModelBase public bool IsConnectedModel => ConnectedModel != null; [ObservableProperty] private bool isLoading; + [ObservableProperty] private CivitModelType modelType; public string FileName => Path.GetFileName((string?) FilePath); @@ -223,6 +225,39 @@ public partial class CheckpointFile : ViewModelBase yield return checkpointFile; } } + + public static IEnumerable GetAllCheckpointFiles(string modelsDirectory) + { + foreach (var file in Directory.EnumerateFiles(modelsDirectory, "*.*", SearchOption.AllDirectories)) + { + if (!SupportedCheckpointExtensions.Any(ext => file.Contains(ext))) + continue; + + var checkpointFile = new CheckpointFile + { + Title = Path.GetFileNameWithoutExtension(file), + FilePath = file, + }; + + var jsonPath = Path.Combine(Path.GetDirectoryName(file), + Path.GetFileNameWithoutExtension(file) + ".cm-info.json"); + + if (File.Exists(jsonPath)) + { + var json = File.ReadAllText(jsonPath); + var connectedModelInfo = ConnectedModelInfo.FromJson(json); + checkpointFile.ConnectedModel = connectedModelInfo; + checkpointFile.ModelType = GetCivitModelType(file); + } + + checkpointFile.PreviewImagePath = SupportedImageExtensions + .Select(ext => Path.Combine(Path.GetDirectoryName(file), + $"{Path.GetFileNameWithoutExtension(file)}.preview{ext}")).Where(File.Exists) + .FirstOrDefault(); + + yield return checkpointFile; + } + } /// /// Index with progress reporting. @@ -238,4 +273,39 @@ public partial class CheckpointFile : ViewModelBase yield return checkpointFile; } } + + private static CivitModelType GetCivitModelType(string filePath) + { + if (filePath.Contains(SharedFolderType.StableDiffusion.ToString())) + { + return CivitModelType.Checkpoint; + } + + if (filePath.Contains(SharedFolderType.ControlNet.ToString())) + { + return CivitModelType.Controlnet; + } + + if (filePath.Contains(SharedFolderType.Lora.ToString())) + { + return CivitModelType.LORA; + } + + if (filePath.Contains(SharedFolderType.TextualInversion.ToString())) + { + return CivitModelType.TextualInversion; + } + + if (filePath.Contains(SharedFolderType.Hypernetwork.ToString())) + { + return CivitModelType.Hypernetwork; + } + + if (filePath.Contains(SharedFolderType.LyCORIS.ToString())) + { + return CivitModelType.LoCon; + } + + return CivitModelType.Unknown; + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs index 19d1e84c..19f99d84 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs @@ -70,7 +70,7 @@ public partial class CheckpointsPageViewModel : PageViewModelBase this.downloadService = downloadService; this.modelFinder = modelFinder; } - + public override async Task OnLoadedAsync() { DisplayedCheckpointFolders = CheckpointFolders; @@ -147,7 +147,7 @@ public partial class CheckpointsPageViewModel : PageViewModelBase var indexTasks = folders.Select(async f => { var checkpointFolder = - new CheckpointManager.CheckpointFolder(settingsManager, downloadService, modelFinder) + new CheckpointFolder(settingsManager, downloadService, modelFinder) { Title = Path.GetFileName(f), DirectoryPath = f, diff --git a/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs new file mode 100644 index 00000000..da2c4549 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs @@ -0,0 +1,187 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using AsyncAwaitBestPractices; +using Avalonia.Controls; +using Avalonia.Controls.Notifications; +using AvaloniaEdit.Utils; +using CommunityToolkit.Mvvm.ComponentModel; +using FluentAvalonia.UI.Controls; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.CheckpointManager; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Avalonia.Views; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Database; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api; +using StabilityMatrix.Core.Services; +using Symbol = FluentIcons.Common.Symbol; +using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; + +namespace StabilityMatrix.Avalonia.ViewModels; + +[View(typeof(NewCheckpointsPage))] +public partial class NewCheckpointsPageViewModel : PageViewModelBase +{ + private readonly ISettingsManager settingsManager; + private readonly ILiteDbContext liteDbContext; + private readonly ICivitApi civitApi; + private readonly ServiceManager dialogFactory; + private readonly INotificationService notificationService; + public override string Title => "Checkpoint Manager"; + public override IconSource IconSource => new SymbolIconSource + {Symbol = Symbol.Cellular5g, IsFilled = true}; + + public NewCheckpointsPageViewModel(ISettingsManager settingsManager, ILiteDbContext liteDbContext, + ICivitApi civitApi, ServiceManager dialogFactory, INotificationService notificationService) + { + this.settingsManager = settingsManager; + this.liteDbContext = liteDbContext; + this.civitApi = civitApi; + this.dialogFactory = dialogFactory; + this.notificationService = notificationService; + } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(ConnectedCheckpoints))] + [NotifyPropertyChangedFor(nameof(NonConnectedCheckpoints))] + private ObservableCollection allCheckpoints = new(); + + [ObservableProperty] + private ObservableCollection civitModels = new(); + + public ObservableCollection ConnectedCheckpoints => new( + AllCheckpoints.Where(x => x.IsConnectedModel) + .OrderBy(x => x.ConnectedModel!.ModelName) + .ThenBy(x => x.ModelType) + .GroupBy(x => x.ConnectedModel!.ModelId) + .Select(x => x.First())); + + public ObservableCollection NonConnectedCheckpoints => new( + AllCheckpoints.Where(x => !x.IsConnectedModel).OrderBy(x => x.ModelType)); + + public override async Task OnLoadedAsync() + { + if (Design.IsDesignMode) return; + + var files = CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory); + AllCheckpoints = new ObservableCollection(files); + + var connectedModelIds = ConnectedCheckpoints.Select(x => x.ConnectedModel.ModelId); + var modelRequest = new CivitModelsRequest + { + CommaSeparatedModelIds = string.Join(',', connectedModelIds) + }; + + // See if query is cached + var cachedQuery = await liteDbContext.CivitModelQueryCache + .IncludeAll() + .FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest)); + + // If cached, update model cards + if (cachedQuery is not null) + { + CivitModels = new ObservableCollection(cachedQuery.Items); + + // Start remote query (background mode) + // Skip when last query was less than 2 min ago + var timeSinceCache = DateTimeOffset.UtcNow - cachedQuery.InsertedAt; + if (timeSinceCache?.TotalMinutes >= 2) + { + CivitQuery(modelRequest).SafeFireAndForget(); + } + } + else + { + await CivitQuery(modelRequest); + } + } + + public async Task ShowVersionDialog(int modelId) + { + var model = CivitModels.FirstOrDefault(m => m.Id == modelId); + if (model == null) + { + notificationService.Show(new Notification("Model has no versions available", + "This model has no versions available for download", NotificationType.Warning)); + return; + } + var versions = model.ModelVersions; + if (versions is null || versions.Count == 0) + { + notificationService.Show(new Notification("Model has no versions available", + "This model has no versions available for download", NotificationType.Warning)); + return; + } + + var dialog = new BetterContentDialog + { + Title = model.Name, + IsPrimaryButtonEnabled = false, + IsSecondaryButtonEnabled = false, + IsFooterVisible = false, + MaxDialogWidth = 750, + }; + + var viewModel = dialogFactory.Get(); + viewModel.Dialog = dialog; + viewModel.Versions = versions.Select(version => + new ModelVersionViewModel( + settingsManager.Settings.InstalledModelHashes ?? new HashSet(), version)) + .ToImmutableArray(); + viewModel.SelectedVersionViewModel = viewModel.Versions[0]; + + dialog.Content = new SelectModelVersionDialog + { + DataContext = viewModel + }; + + var result = await dialog.ShowAsync(); + + if (result != ContentDialogResult.Primary) + { + return; + } + + var selectedVersion = viewModel?.SelectedVersionViewModel?.ModelVersion; + var selectedFile = viewModel?.SelectedFile?.CivitFile; + } + + private async Task CivitQuery(CivitModelsRequest request) + { + var modelResponse = await civitApi.GetModels(request); + var models = modelResponse.Items; + // Filter out unknown model types and archived/taken-down models + models = models.Where(m => m.Type.ConvertTo() > 0) + .Where(m => m.Mode == null).ToList(); + + // Database update calls will invoke `OnModelsUpdated` + // Add to database + await liteDbContext.UpsertCivitModelAsync(models); + // Add as cache entry + var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(new CivitModelQueryCacheEntry + { + Id = ObjectHash.GetMd5Guid(request), + InsertedAt = DateTimeOffset.UtcNow, + Request = request, + Items = models, + Metadata = modelResponse.Metadata + }); + + if (cacheNew) + { + CivitModels = new ObservableCollection(models); + } + } +} diff --git a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml index a8352c5d..19a3b3ed 100644 --- a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml +++ b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml @@ -40,7 +40,7 @@ Margin="0,8,0,8" Height="300" StretchDirection="Both" - CornerRadius="4" + CornerRadius="8" VerticalContentAlignment="Top" HorizontalContentAlignment="Center" Source="{Binding CardImage}" diff --git a/StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml b/StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml new file mode 100644 index 00000000..05bfb412 --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml @@ -0,0 +1,105 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + [AliasAs("baseModels")] public string? BaseModel { get; set; } + + [AliasAs("ids")] + public string CommaSeparatedModelIds { get; set; } }