diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs index 42bad566..f7cc495e 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs @@ -51,48 +51,93 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase private readonly ILiteDbContext liteDbContext; private readonly INotificationService notificationService; private const int MaxModelsPerPage = 14; - private LRUCache cache = new(50); - - [ObservableProperty] private ObservableCollection? modelCards; - [ObservableProperty] private DataGridCollectionView? modelCardsView; - - [ObservableProperty] private string searchQuery = string.Empty; - [ObservableProperty] private bool showNsfw; - [ObservableProperty] private bool showMainLoadingSpinner; - [ObservableProperty] private CivitPeriod selectedPeriod = CivitPeriod.Month; - [ObservableProperty] private CivitSortMode sortMode = CivitSortMode.HighestRated; - [ObservableProperty] private CivitModelType selectedModelType = CivitModelType.Checkpoint; - [ObservableProperty] private int currentPageNumber; - [ObservableProperty] private int totalPages; - [ObservableProperty] private bool hasSearched; - [ObservableProperty] private bool canGoToNextPage; - [ObservableProperty] private bool canGoToPreviousPage; - [ObservableProperty] private bool canGoToFirstPage; - [ObservableProperty] private bool canGoToLastPage; - [ObservableProperty] private bool isIndeterminate; - [ObservableProperty] private bool noResultsFound; - [ObservableProperty] private string noResultsText = string.Empty; - [ObservableProperty] private string selectedBaseModelType = "All"; - + private LRUCache< + int /* model id */ + , + CheckpointBrowserCardViewModel + > cache = new(50); + + [ObservableProperty] + private ObservableCollection? modelCards; + + [ObservableProperty] + private DataGridCollectionView? modelCardsView; + + [ObservableProperty] + private string searchQuery = string.Empty; + + [ObservableProperty] + private bool showNsfw; + + [ObservableProperty] + private bool showMainLoadingSpinner; + + [ObservableProperty] + private CivitPeriod selectedPeriod = CivitPeriod.Month; + + [ObservableProperty] + private CivitSortMode sortMode = CivitSortMode.HighestRated; + + [ObservableProperty] + private CivitModelType selectedModelType = CivitModelType.Checkpoint; + + [ObservableProperty] + private int currentPageNumber; + + [ObservableProperty] + private int totalPages; + + [ObservableProperty] + private bool hasSearched; + + [ObservableProperty] + private bool canGoToNextPage; + + [ObservableProperty] + private bool canGoToPreviousPage; + + [ObservableProperty] + private bool canGoToFirstPage; + + [ObservableProperty] + private bool canGoToLastPage; + + [ObservableProperty] + private bool isIndeterminate; + + [ObservableProperty] + private bool noResultsFound; + + [ObservableProperty] + private string noResultsText = string.Empty; + + [ObservableProperty] + private string selectedBaseModelType = "All"; + private List allModelCards = new(); - - public IEnumerable AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast(); - public IEnumerable AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast(); - public IEnumerable AllModelTypes => Enum.GetValues(typeof(CivitModelType)) - .Cast() - .Where(t => t == CivitModelType.All || t.ConvertTo() > 0) - .OrderBy(t => t.ToString()); + public IEnumerable AllCivitPeriods => + Enum.GetValues(typeof(CivitPeriod)).Cast(); + public IEnumerable AllSortModes => + Enum.GetValues(typeof(CivitSortMode)).Cast(); + + public IEnumerable AllModelTypes => + Enum.GetValues(typeof(CivitModelType)) + .Cast() + .Where(t => t == CivitModelType.All || t.ConvertTo() > 0) + .OrderBy(t => t.ToString()); - public List BaseModelOptions => new() {"All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0"}; + public List BaseModelOptions => + new() { "All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0" }; public CheckpointBrowserViewModel( - ICivitApi civitApi, - IDownloadService downloadService, + ICivitApi civitApi, + IDownloadService downloadService, ISettingsManager settingsManager, ServiceManager dialogFactory, ILiteDbContext liteDbContext, - INotificationService notificationService) + INotificationService notificationService + ) { this.civitApi = civitApi; this.downloadService = downloadService; @@ -117,27 +162,38 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase public override void OnLoaded() { - if (Design.IsDesignMode) return; - + if (Design.IsDesignMode) + return; + var searchOptions = settingsManager.Settings.ModelSearchOptions; - + // Fix SelectedModelType if someone had selected the obsolete "Model" option - if (searchOptions is {SelectedModelType: CivitModelType.Model}) + if (searchOptions is { SelectedModelType: CivitModelType.Model }) { - settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( - SelectedPeriod, SortMode, CivitModelType.Checkpoint, SelectedBaseModelType)); + settingsManager.Transaction( + s => + s.ModelSearchOptions = new ModelSearchOptions( + SelectedPeriod, + SortMode, + CivitModelType.Checkpoint, + SelectedBaseModelType + ) + ); searchOptions = settingsManager.Settings.ModelSearchOptions; } - + SelectedPeriod = searchOptions?.SelectedPeriod ?? CivitPeriod.Month; SortMode = searchOptions?.SortMode ?? CivitSortMode.HighestRated; SelectedModelType = searchOptions?.SelectedModelType ?? CivitModelType.Checkpoint; SelectedBaseModelType = searchOptions?.SelectedBaseModelType ?? "All"; - + ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled; - - settingsManager.RelayPropertyFor(this, model => model.ShowNsfw, - settings => settings.ModelBrowserNsfwEnabled); + + settingsManager.RelayPropertyFor( + this, + model => model.ShowNsfw, + settings => settings.ModelBrowserNsfwEnabled + ); } /// @@ -145,7 +201,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase /// private bool FilterModelCardsPredicate(object? item) { - if (item is not CheckpointBrowserCardViewModel card) return false; + if (item is not CheckpointBrowserCardViewModel card) + return false; return !card.CivitModel.Nsfw || ShowNsfw; } @@ -162,38 +219,52 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase var models = modelsResponse.Items; if (models is null) { - Logger.Debug("CivitAI Query {Text} returned no results (in {Elapsed:F1} s)", - queryText, timer.Elapsed.TotalSeconds); + Logger.Debug( + "CivitAI Query {Text} returned no results (in {Elapsed:F1} s)", + queryText, + timer.Elapsed.TotalSeconds + ); return; } - Logger.Debug("CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)", - queryText, models.Count, timer.Elapsed.TotalSeconds); + Logger.Debug( + "CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)", + queryText, + models.Count, + timer.Elapsed.TotalSeconds + ); var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList(); if (unknown.Any()) { var names = unknown.Select(m => m.Name).ToList(); - Logger.Warn("Excluded {Unknown} unknown model types: {Models}", unknown.Count, - names); + Logger.Warn( + "Excluded {Unknown} unknown model types: {Models}", + unknown.Count, + names + ); } // Filter out unknown model types and archived/taken-down models - models = models.Where(m => m.Type.ConvertTo() > 0) - .Where(m => m.Mode == null).ToList(); + models = models + .Where(m => m.Type.ConvertTo() > 0) + .Where(m => m.Mode == null) + .ToList(); // Database update calls will invoke `OnModelsUpdated` // Add to database await liteDbContext.UpsertCivitModelAsync(models); // Add as cache entry - var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(new() - { - Id = ObjectHash.GetMd5Guid(request), - InsertedAt = DateTimeOffset.UtcNow, - Request = request, - Items = models, - Metadata = modelsResponse.Metadata - }); + var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync( + new() + { + Id = ObjectHash.GetMd5Guid(request), + InsertedAt = DateTimeOffset.UtcNow, + Request = request, + Items = models, + Metadata = modelsResponse.Metadata + } + ); if (cacheNew) { @@ -207,26 +278,42 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase } catch (OperationCanceledException) { - notificationService.Show(new Notification("Request to CivitAI timed out", - "Please try again in a few minutes")); + notificationService.Show( + new Notification( + "Request to CivitAI timed out", + "Please try again in a few minutes" + ) + ); Logger.Warn($"CivitAI query timed out ({request})"); } catch (HttpRequestException e) { - notificationService.Show(new Notification("CivitAI can't be reached right now", - "Please try again in a few minutes")); + notificationService.Show( + new Notification( + "CivitAI can't be reached right now", + "Please try again in a few minutes" + ) + ); Logger.Warn(e, $"CivitAI query HttpRequestException ({request})"); } catch (ApiException e) { - notificationService.Show(new Notification("CivitAI can't be reached right now", - "Please try again in a few minutes")); + notificationService.Show( + new Notification( + "CivitAI can't be reached right now", + "Please try again in a few minutes" + ) + ); Logger.Warn(e, $"CivitAI query ApiException ({request})"); } catch (Exception e) { - notificationService.Show(new Notification("CivitAI can't be reached right now", - $"Unknown exception during CivitAI query: {e.GetType().Name}")); + notificationService.Show( + new Notification( + "CivitAI can't be reached right now", + $"Unknown exception during CivitAI query: {e.GetType().Name}" + ) + ); Logger.Error(e, $"CivitAI query unknown exception ({request})"); } finally @@ -235,11 +322,11 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase UpdateResultsText(); } } - + /// /// Updates model cards using api response object. /// - private void UpdateModelCards(IEnumerable? models, CivitMetadata? metadata) + private void UpdateModelCards(IEnumerable? models, CivitMetadata? metadata) { if (models is null) { @@ -266,26 +353,29 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase vm.CivitModel = model; vm.OnDownloadStart = viewModel => { - if (cache.Get(viewModel.CivitModel.Id) != null) return; + if (cache.Get(viewModel.CivitModel.Id) != null) + return; cache.Add(viewModel.CivitModel.Id, viewModel); }; return vm; }); - + return newCard; - }).ToList(); - + }) + .ToList(); + allModelCards = updateCards; var filteredCards = updateCards.Where(FilterModelCardsPredicate); if (SortMode == CivitSortMode.Installed) { - filteredCards = - filteredCards.OrderByDescending(x => x.UpdateCardText == "Update Available"); + filteredCards = filteredCards.OrderByDescending( + x => x.UpdateCardText == "Update Available" + ); } - - ModelCards =new ObservableCollection(filteredCards); + + ModelCards = new ObservableCollection(filteredCards); } TotalPages = metadata?.TotalPages ?? 1; CanGoToFirstPage = CurrentPageNumber != 1; @@ -304,14 +394,14 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase private async Task SearchModels() { var timer = Stopwatch.StartNew(); - + if (SearchQuery != previousSearchQuery) { // Reset page number CurrentPageNumber = 1; previousSearchQuery = SearchQuery; } - + // Build request var modelRequest = new CivitModelsRequest { @@ -334,10 +424,10 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase { modelRequest.Query = SearchQuery; } - + if (SelectedModelType != CivitModelType.All) { - modelRequest.Types = new[] {SelectedModelType}; + modelRequest.Types = new[] { SelectedModelType }; } if (SelectedBaseModelType != "All") @@ -347,9 +437,9 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase if (SortMode == CivitSortMode.Installed) { - var connectedModels = - CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory) - .Where(c => c.IsConnectedModel); + var connectedModels = CheckpointFile + .GetAllCheckpointFiles(settingsManager.ModelsDirectory) + .Where(c => c.IsConnectedModel); if (SelectedModelType != CivitModelType.All) { @@ -358,24 +448,34 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase modelRequest = new CivitModelsRequest { - CommaSeparatedModelIds = string.Join(",", - connectedModels.Select(c => c.ConnectedModel!.ModelId).GroupBy(m => m) - .Select(g => g.First())), - Types = SelectedModelType == CivitModelType.All ? null : new[] {SelectedModelType} + CommaSeparatedModelIds = string.Join( + ",", + connectedModels + .Select(c => c.ConnectedModel!.ModelId) + .GroupBy(m => m) + .Select(g => g.First()) + ), + Types = + SelectedModelType == CivitModelType.All ? null : new[] { SelectedModelType }, + Page = CurrentPageNumber, }; } - + // See if query is cached var cachedQuery = await liteDbContext.CivitModelQueryCache .IncludeAll() .FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest)); - + // If cached, update model cards if (cachedQuery is not null) { var elapsed = timer.Elapsed; - Logger.Debug("Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)", - SearchQuery, modelRequest.GetHashCode(), elapsed.TotalSeconds); + Logger.Debug( + "Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)", + SearchQuery, + modelRequest.GetHashCode(), + elapsed.TotalSeconds + ); UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata); // Start remote query (background mode) @@ -386,7 +486,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase CivitModelQuery(modelRequest).SafeFireAndForget(); Logger.Debug( "Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query", - timeSinceCache.Value.TotalSeconds); + timeSinceCache.Value.TotalSeconds + ); } } else @@ -395,7 +496,7 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase ShowMainLoadingSpinner = true; await CivitModelQuery(modelRequest); } - + UpdateResultsText(); } @@ -406,17 +507,17 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase public void PreviousPage() { - if (CurrentPageNumber == 1) + if (CurrentPageNumber == 1) return; - + CurrentPageNumber--; } - + public void NextPage() { - if (CurrentPageNumber == TotalPages) + if (CurrentPageNumber == TotalPages) return; - + CurrentPageNumber++; } @@ -429,46 +530,75 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase { settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value); // ModelCardsView?.Refresh(); - var updateCards = allModelCards - .Where(FilterModelCardsPredicate); + var updateCards = allModelCards.Where(FilterModelCardsPredicate); ModelCards = new ObservableCollection(updateCards); - if (!HasSearched) return; - + if (!HasSearched) + return; + UpdateResultsText(); } partial void OnSelectedPeriodChanged(CivitPeriod value) { TrySearchAgain().SafeFireAndForget(); - settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( - value, SortMode, SelectedModelType, SelectedBaseModelType)); + settingsManager.Transaction( + s => + s.ModelSearchOptions = new ModelSearchOptions( + value, + SortMode, + SelectedModelType, + SelectedBaseModelType + ) + ); } partial void OnSortModeChanged(CivitSortMode value) { TrySearchAgain().SafeFireAndForget(); - settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( - SelectedPeriod, value, SelectedModelType, SelectedBaseModelType)); + settingsManager.Transaction( + s => + s.ModelSearchOptions = new ModelSearchOptions( + SelectedPeriod, + value, + SelectedModelType, + SelectedBaseModelType + ) + ); } - + partial void OnSelectedModelTypeChanged(CivitModelType value) { TrySearchAgain().SafeFireAndForget(); - settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( - SelectedPeriod, SortMode, value, SelectedBaseModelType)); + settingsManager.Transaction( + s => + s.ModelSearchOptions = new ModelSearchOptions( + SelectedPeriod, + SortMode, + value, + SelectedBaseModelType + ) + ); } partial void OnSelectedBaseModelTypeChanged(string value) { TrySearchAgain().SafeFireAndForget(); - settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( - SelectedPeriod, SortMode, SelectedModelType, value)); + settingsManager.Transaction( + s => + s.ModelSearchOptions = new ModelSearchOptions( + SelectedPeriod, + SortMode, + SelectedModelType, + value + ) + ); } private async Task TrySearchAgain(bool shouldUpdatePageNumber = true) { - if (!HasSearched) return; + if (!HasSearched) + return; ModelCards?.Clear(); if (shouldUpdatePageNumber) @@ -483,11 +613,13 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase private void UpdateResultsText() { NoResultsFound = ModelCards?.Count <= 0; - NoResultsText = allModelCards.Count > 0 - ? $"{allModelCards.Count} results hidden by filters" - : "No results found"; + NoResultsText = + allModelCards.Count > 0 + ? $"{allModelCards.Count} results hidden by filters" + : "No results found"; } public override string Title => "Model Browser"; - public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true }; + public override IconSource IconSource => + new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true }; }