using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Net.Http;
using System.Reactive;
using System.Reactive.Linq;
using System.Reactive.Threading.Tasks;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Collections;
using Avalonia.Controls;
using AvaloniaEdit.Utils;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using Refit;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Services;
using Notification = Avalonia.Controls.Notifications.Notification;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;

namespace StabilityMatrix.Avalonia.ViewModels;

/// <summary>
/// View model for the "Model Browser" page. Queries CivitAI for models, stores the
/// returned models and the query result in the local database cache, and exposes
/// paginated, NSFW-filtered model cards to the view.
/// </summary>
[View(typeof(CheckpointBrowserPage))]
public partial class CheckpointBrowserViewModel : PageViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly ICivitApi civitApi;
    private readonly IDownloadService downloadService;
    private readonly ISettingsManager settingsManager;
    private readonly ServiceManager dialogFactory;
    private readonly ILiteDbContext liteDbContext;
    private readonly INotificationService notificationService;

    /// <summary>Page size sent to the CivitAI models endpoint.</summary>
    private const int MaxModelsPerPage = 14;

    /// <summary>
    /// Card view models registered when a download starts (see <c>OnDownloadStart</c> in
    /// <see cref="UpdateModelCards"/>), keyed by CivitAI model id, so an in-progress import
    /// keeps the same card instance when the same model shows up in a later result set.
    /// </summary>
    private readonly LRUCache cache = new(50);

    /// <summary>Cards currently shown (after filtering and ordering).</summary>
    [ObservableProperty]
    private ObservableCollection? modelCards;

    [ObservableProperty]
    private DataGridCollectionView? modelCardsView;

    /// <summary>
    /// Search text. A leading '#' searches by tag, a leading '@' searches by username,
    /// anything else is a free-text query.
    /// </summary>
    [ObservableProperty]
    private string searchQuery = string.Empty;

    [ObservableProperty]
    private bool showNsfw;

    [ObservableProperty]
    private bool showMainLoadingSpinner;

    [ObservableProperty]
    private CivitPeriod selectedPeriod = CivitPeriod.Month;

    [ObservableProperty]
    private CivitSortMode sortMode = CivitSortMode.HighestRated;

    [ObservableProperty]
    private CivitModelType selectedModelType = CivitModelType.Checkpoint;

    [ObservableProperty]
    private int currentPageNumber;

    [ObservableProperty]
    private int totalPages;

    /// <summary>True once a result set has been applied to the cards at least once.</summary>
    [ObservableProperty]
    private bool hasSearched;

    [ObservableProperty]
    private bool canGoToNextPage;

    [ObservableProperty]
    private bool canGoToPreviousPage;

    [ObservableProperty]
    private bool canGoToFirstPage;

    [ObservableProperty]
    private bool canGoToLastPage;

    [ObservableProperty]
    private bool isIndeterminate;

    [ObservableProperty]
    private bool noResultsFound;

    [ObservableProperty]
    private string noResultsText = string.Empty;

    [ObservableProperty]
    private string selectedBaseModelType = "All";

    /// <summary>Every card from the latest result set, before the NSFW filter is applied.</summary>
    private List allModelCards = new();

    public IEnumerable AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast();

    public IEnumerable AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast();

    /// <summary>
    /// Model types offered in the type filter: "All" plus every type whose
    /// <c>ConvertTo()</c> result is greater than zero, ordered by name.
    /// </summary>
    public IEnumerable AllModelTypes =>
        Enum.GetValues(typeof(CivitModelType))
            .Cast()
            .Where(t => t == CivitModelType.All || t.ConvertTo() > 0)
            .OrderBy(t => t.ToString());

    public List BaseModelOptions => new() { "All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0" };

    public CheckpointBrowserViewModel(
        ICivitApi civitApi,
        IDownloadService downloadService,
        ISettingsManager settingsManager,
        ServiceManager dialogFactory,
        ILiteDbContext liteDbContext,
        INotificationService notificationService
    )
    {
        this.civitApi = civitApi;
        this.downloadService = downloadService;
        this.settingsManager = settingsManager;
        this.dialogFactory = dialogFactory;
        this.liteDbContext = liteDbContext;
        this.notificationService = notificationService;

        CurrentPageNumber = 1;
        CanGoToNextPage = true;
        CanGoToLastPage = true;

        // Debounce page-number changes (250 ms) and re-run the search for the new page,
        // keeping the page number (TrySearchAgain(false)). Pages outside 1..TotalPages are ignored.
        // NOTE(review): ObserveOn(SynchronizationContext) rejects a null context, so this assumes
        // the view model is constructed on the UI thread - confirm.
        Observable
            .FromEventPattern(this, nameof(PropertyChanged))
            .Where(x => x.EventArgs.PropertyName == nameof(CurrentPageNumber))
            .Throttle(TimeSpan.FromMilliseconds(250))
            .Select(_ => CurrentPageNumber)
            .Where(page => page <= TotalPages && page > 0)
            .ObserveOn(SynchronizationContext.Current)
            .Subscribe(_ => TrySearchAgain(false).SafeFireAndForget(), err => Logger.Error(err));
    }

    /// <summary>
    /// Restores the persisted search options and NSFW preference, migrating the obsolete
    /// "Model" type selection to "Checkpoint".
    /// </summary>
    public override void OnLoaded()
    {
        if (Design.IsDesignMode)
            return;

        var searchOptions = settingsManager.Settings.ModelSearchOptions;

        // Fix SelectedModelType if someone had selected the obsolete "Model" option.
        // Keep the user's saved period, sort and base model; only the type is replaced.
        if (searchOptions is { SelectedModelType: CivitModelType.Model } obsoleteOptions)
        {
            settingsManager.Transaction(
                s =>
                    s.ModelSearchOptions = new ModelSearchOptions(
                        obsoleteOptions.SelectedPeriod,
                        obsoleteOptions.SortMode,
                        CivitModelType.Checkpoint,
                        obsoleteOptions.SelectedBaseModelType ?? "All"
                    )
            );
            searchOptions = settingsManager.Settings.ModelSearchOptions;
        }

        SelectedPeriod = searchOptions?.SelectedPeriod ?? CivitPeriod.Month;
        SortMode = searchOptions?.SortMode ?? CivitSortMode.HighestRated;
        SelectedModelType = searchOptions?.SelectedModelType ?? CivitModelType.Checkpoint;
        SelectedBaseModelType = searchOptions?.SelectedBaseModelType ?? "All";

        ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled;

        // NOTE(review): this runs on every load; confirm RelayPropertyFor does not register
        // a duplicate relay each time the page is loaded.
        settingsManager.RelayPropertyFor(
            this,
            model => model.ShowNsfw,
            settings => settings.ModelBrowserNsfwEnabled
        );
    }

    /// <summary>
    /// Filter predicate for model cards: keeps non-NSFW cards always and NSFW cards only
    /// when <see cref="ShowNsfw"/> is enabled. Non-card items are rejected.
    /// </summary>
    private bool FilterModelCardsPredicate(object? item)
    {
        if (item is not CheckpointBrowserCardViewModel card)
            return false;

        return !card.CivitModel.Nsfw || ShowNsfw;
    }

    /// <summary>
    /// Runs <paramref name="request"/> against CivitAI, stores the models and the query result
    /// in the database cache, and refreshes the cards when the cache entry is new.
    /// Failures are reported through a notification and logged, never thrown.
    /// </summary>
    /// <param name="request">The CivitAI models request to execute.</param>
    private async Task CivitModelQuery(CivitModelsRequest request)
    {
        var timer = Stopwatch.StartNew();
        var queryText = request.Query;
        try
        {
            var modelsResponse = await civitApi.GetModels(request);
            var models = modelsResponse.Items;
            if (models is null)
            {
                Logger.Debug(
                    "CivitAI Query {Text} returned no results (in {Elapsed:F1} s)",
                    queryText,
                    timer.Elapsed.TotalSeconds
                );
                return;
            }

            Logger.Debug(
                "CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)",
                queryText,
                models.Count,
                timer.Elapsed.TotalSeconds
            );

            var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList();
            if (unknown.Any())
            {
                var names = unknown.Select(m => m.Name).ToList();
                Logger.Warn("Excluded {Unknown} unknown model types: {Models}", unknown.Count, names);
            }

            // Filter out unknown model types and archived/taken-down models
            models = models.Where(m => m.Type.ConvertTo() > 0).Where(m => m.Mode == null).ToList();

            // Database update calls will invoke `OnModelsUpdated`
            // Add to database
            await liteDbContext.UpsertCivitModelAsync(models);

            // Add as cache entry
            var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(
                new()
                {
                    Id = ObjectHash.GetMd5Guid(request),
                    InsertedAt = DateTimeOffset.UtcNow,
                    Request = request,
                    Items = models,
                    Metadata = modelsResponse.Metadata
                }
            );

            if (cacheNew)
            {
                Logger.Debug("New cache entry, updating model cards");
                UpdateModelCards(models, modelsResponse.Metadata);
            }
            else
            {
                Logger.Debug("Cache entry already exists, not updating model cards");
            }
        }
        // Structured templates below keep the request text out of the template itself,
        // so braces in its ToString() output cannot be parsed as placeholders.
        catch (OperationCanceledException)
        {
            // Cancellation of the request is surfaced to the user as a timeout
            notificationService.Show(
                new Notification("Request to CivitAI timed out", "Please try again in a few minutes")
            );
            Logger.Warn("CivitAI query timed out ({Request})", request);
        }
        catch (HttpRequestException e)
        {
            notificationService.Show(
                new Notification("CivitAI can't be reached right now", "Please try again in a few minutes")
            );
            Logger.Warn(e, "CivitAI query HttpRequestException ({Request})", request);
        }
        catch (ApiException e)
        {
            notificationService.Show(
                new Notification("CivitAI can't be reached right now", "Please try again in a few minutes")
            );
            Logger.Warn(e, "CivitAI query ApiException ({Request})", request);
        }
        catch (Exception e)
        {
            notificationService.Show(
                new Notification(
                    "CivitAI can't be reached right now",
                    $"Unknown exception during CivitAI query: {e.GetType().Name}"
                )
            );
            Logger.Error(e, "CivitAI query unknown exception ({Request})", request);
        }
        finally
        {
            ShowMainLoadingSpinner = false;
            UpdateResultsText();
        }
    }

    /// <summary>
    /// Updates model cards and pagination state from an API (or cached) response.
    /// </summary>
    /// <param name="models">Models of the result page; <c>null</c> clears all cards.</param>
    /// <param name="metadata">Paging metadata; when <c>null</c>, <see cref="TotalPages"/> becomes 1.</param>
    private void UpdateModelCards(IEnumerable? models, CivitMetadata? metadata)
    {
        if (models is null)
        {
            // Drop the unfiltered set too; otherwise a later filter change (e.g. toggling NSFW)
            // would resurrect the previous results and the results text would count them as hidden.
            allModelCards.Clear();
            ModelCards?.Clear();
        }
        else
        {
            allModelCards = models
                .Select(model =>
                {
                    // Reuse a card registered in the cache (by OnDownloadStart below); it is
                    // evicted here once it is no longer importing.
                    var cachedViewModel = cache.Get(model.Id);
                    if (cachedViewModel != null)
                    {
                        if (!cachedViewModel.IsImporting)
                        {
                            cache.Remove(model.Id);
                        }

                        return cachedViewModel;
                    }

                    var newCard = dialogFactory.Get(vm =>
                    {
                        vm.CivitModel = model;
                        vm.OnDownloadStart = viewModel =>
                        {
                            if (cache.Get(viewModel.CivitModel.Id) != null)
                                return;
                            cache.Add(viewModel.CivitModel.Id, viewModel);
                        };

                        return vm;
                    });

                    return newCard;
                })
                .ToList();

            RefreshVisibleModelCards();
        }

        TotalPages = metadata?.TotalPages ?? 1;
        CanGoToFirstPage = CurrentPageNumber != 1;
        CanGoToPreviousPage = CurrentPageNumber > 1;
        CanGoToNextPage = CurrentPageNumber < TotalPages;
        CanGoToLastPage = CurrentPageNumber != TotalPages;

        // Status update
        ShowMainLoadingSpinner = false;
        IsIndeterminate = false;
        HasSearched = true;
    }

    /// <summary>
    /// Rebuilds <see cref="ModelCards"/> from <see cref="allModelCards"/>, applying the NSFW
    /// filter and, in <see cref="CivitSortMode.Installed"/> mode, listing cards with an
    /// available update first.
    /// </summary>
    private void RefreshVisibleModelCards()
    {
        var visibleCards = allModelCards.Where(FilterModelCardsPredicate);

        if (SortMode == CivitSortMode.Installed)
        {
            visibleCards = visibleCards.OrderByDescending(x => x.UpdateCardText == "Update Available");
        }

        ModelCards = new ObservableCollection(visibleCards);
    }

    /// <summary>Query text of the previous search; a change resets paging to page 1.</summary>
    private string previousSearchQuery = string.Empty;

    /// <summary>
    /// Builds a CivitAI request from the current search state and updates the cards: a cached
    /// result is shown immediately (and refreshed in the background when at least 2 minutes old);
    /// otherwise the remote query is awaited.
    /// </summary>
    [RelayCommand]
    private async Task SearchModels()
    {
        var timer = Stopwatch.StartNew();

        if (SearchQuery != previousSearchQuery)
        {
            // Reset page number
            CurrentPageNumber = 1;
            previousSearchQuery = SearchQuery;
        }

        // Build request
        var modelRequest = new CivitModelsRequest
        {
            Limit = MaxModelsPerPage,
            Nsfw = "true", // Handled by local view filter
            Sort = SortMode,
            Period = SelectedPeriod,
            Page = CurrentPageNumber
        };

        // Ordinal char checks: '#' = tag search, '@' = username search
        if (SearchQuery.StartsWith('#'))
        {
            modelRequest.Tag = SearchQuery[1..];
        }
        else if (SearchQuery.StartsWith('@'))
        {
            modelRequest.Username = SearchQuery[1..];
        }
        else
        {
            modelRequest.Query = SearchQuery;
        }

        if (SelectedModelType != CivitModelType.All)
        {
            modelRequest.Types = new[] { SelectedModelType };
        }

        if (SelectedBaseModelType != "All")
        {
            modelRequest.BaseModel = SelectedBaseModelType;
        }

        if (SortMode == CivitSortMode.Installed)
        {
            // Installed mode replaces the whole request with a lookup of the locally
            // connected model ids (deduplicated), optionally restricted to the selected type.
            var connectedModels = CheckpointFile
                .GetAllCheckpointFiles(settingsManager.ModelsDirectory)
                .Where(c => c.IsConnectedModel);

            if (SelectedModelType != CivitModelType.All)
            {
                connectedModels = connectedModels.Where(c => c.ModelType == SelectedModelType);
            }

            modelRequest = new CivitModelsRequest
            {
                CommaSeparatedModelIds = string.Join(
                    ",",
                    connectedModels
                        .Select(c => c.ConnectedModel!.ModelId)
                        .GroupBy(m => m)
                        .Select(g => g.First())
                ),
                Types = SelectedModelType == CivitModelType.All ? null : new[] { SelectedModelType }
            };
        }

        // See if query is cached
        var cachedQuery = await liteDbContext.CivitModelQueryCache
            .IncludeAll()
            .FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest));

        // If cached, update model cards
        if (cachedQuery is not null)
        {
            var elapsed = timer.Elapsed;
            Logger.Debug(
                "Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)",
                SearchQuery,
                modelRequest.GetHashCode(),
                elapsed.TotalSeconds
            );
            UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata);

            // Start remote query (background mode)
            // Skip when last query was less than 2 min ago
            var timeSinceCache = DateTimeOffset.UtcNow - cachedQuery.InsertedAt;
            if (timeSinceCache?.TotalMinutes >= 2)
            {
                CivitModelQuery(modelRequest).SafeFireAndForget();
                Logger.Debug(
                    "Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query",
                    timeSinceCache.Value.TotalSeconds
                );
            }
        }
        else
        {
            // Not cached, wait for remote query
            ShowMainLoadingSpinner = true;
            await CivitModelQuery(modelRequest);
        }

        UpdateResultsText();
    }

    public void FirstPage()
    {
        CurrentPageNumber = 1;
    }

    public void PreviousPage()
    {
        if (CurrentPageNumber == 1)
            return;

        CurrentPageNumber--;
    }

    public void NextPage()
    {
        if (CurrentPageNumber == TotalPages)
            return;

        CurrentPageNumber++;
    }

    public void LastPage()
    {
        CurrentPageNumber = TotalPages;
    }

    /// <summary>
    /// Persists the NSFW preference and re-applies the card filter (and Installed-mode ordering)
    /// to the current result set without querying again.
    /// </summary>
    partial void OnShowNsfwChanged(bool value)
    {
        settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value);

        RefreshVisibleModelCards();

        if (!HasSearched)
            return;

        UpdateResultsText();
    }

    partial void OnSelectedPeriodChanged(CivitPeriod value)
    {
        TrySearchAgain().SafeFireAndForget();
        settingsManager.Transaction(
            s =>
                s.ModelSearchOptions = new ModelSearchOptions(
                    value,
                    SortMode,
                    SelectedModelType,
                    SelectedBaseModelType
                )
        );
    }

    partial void OnSortModeChanged(CivitSortMode value)
    {
        TrySearchAgain().SafeFireAndForget();
        settingsManager.Transaction(
            s =>
                s.ModelSearchOptions = new ModelSearchOptions(
                    SelectedPeriod,
                    value,
                    SelectedModelType,
                    SelectedBaseModelType
                )
        );
    }

    partial void OnSelectedModelTypeChanged(CivitModelType value)
    {
        TrySearchAgain().SafeFireAndForget();
        settingsManager.Transaction(
            s =>
                s.ModelSearchOptions = new ModelSearchOptions(
                    SelectedPeriod,
                    SortMode,
                    value,
                    SelectedBaseModelType
                )
        );
    }

    partial void OnSelectedBaseModelTypeChanged(string value)
    {
        TrySearchAgain().SafeFireAndForget();
        settingsManager.Transaction(
            s =>
                s.ModelSearchOptions = new ModelSearchOptions(
                    SelectedPeriod,
                    SortMode,
                    SelectedModelType,
                    value
                )
        );
    }

    /// <summary>
    /// Re-runs the search if one has already been made, clearing the visible cards first.
    /// </summary>
    /// <param name="shouldUpdatePageNumber">When true, jump back to page 1 before searching.</param>
    private async Task TrySearchAgain(bool shouldUpdatePageNumber = true)
    {
        if (!HasSearched)
            return;

        ModelCards?.Clear();

        if (shouldUpdatePageNumber)
        {
            CurrentPageNumber = 1;
        }

        // execute command instead of calling method directly so that the IsRunning property gets updated
        await SearchModelsCommand.ExecuteAsync(null);
    }

    /// <summary>Updates the "no results" flag and message from the current card state.</summary>
    private void UpdateResultsText()
    {
        // Lifted comparison: a null ModelCards yields false, so "no results" is only
        // reported once a card collection exists and is empty.
        NoResultsFound = ModelCards?.Count <= 0;
        NoResultsText =
            allModelCards.Count > 0 ? $"{allModelCards.Count} results hidden by filters" : "No results found";
    }

    public override string Title => "Model Browser";

    public override IconSource IconSource =>
        new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true };
}