Browse Source

Merge pull request #236 from ionite34/fix-installed-model-search

Send page number in request for Installed models
pull/117/head
JT 1 year ago committed by GitHub
parent
commit
3ecb3cdb0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 372
      StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs

372
StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs

@ -51,48 +51,93 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private readonly ILiteDbContext liteDbContext; private readonly ILiteDbContext liteDbContext;
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private const int MaxModelsPerPage = 14; private const int MaxModelsPerPage = 14;
private LRUCache<int /* model id */, CheckpointBrowserCardViewModel> cache = new(50); private LRUCache<
int /* model id */
[ObservableProperty] private ObservableCollection<CheckpointBrowserCardViewModel>? modelCards; ,
[ObservableProperty] private DataGridCollectionView? modelCardsView; CheckpointBrowserCardViewModel
> cache = new(50);
[ObservableProperty] private string searchQuery = string.Empty;
[ObservableProperty] private bool showNsfw; [ObservableProperty]
[ObservableProperty] private bool showMainLoadingSpinner; private ObservableCollection<CheckpointBrowserCardViewModel>? modelCards;
[ObservableProperty] private CivitPeriod selectedPeriod = CivitPeriod.Month;
[ObservableProperty] private CivitSortMode sortMode = CivitSortMode.HighestRated; [ObservableProperty]
[ObservableProperty] private CivitModelType selectedModelType = CivitModelType.Checkpoint; private DataGridCollectionView? modelCardsView;
[ObservableProperty] private int currentPageNumber;
[ObservableProperty] private int totalPages; [ObservableProperty]
[ObservableProperty] private bool hasSearched; private string searchQuery = string.Empty;
[ObservableProperty] private bool canGoToNextPage;
[ObservableProperty] private bool canGoToPreviousPage; [ObservableProperty]
[ObservableProperty] private bool canGoToFirstPage; private bool showNsfw;
[ObservableProperty] private bool canGoToLastPage;
[ObservableProperty] private bool isIndeterminate; [ObservableProperty]
[ObservableProperty] private bool noResultsFound; private bool showMainLoadingSpinner;
[ObservableProperty] private string noResultsText = string.Empty;
[ObservableProperty] private string selectedBaseModelType = "All"; [ObservableProperty]
private CivitPeriod selectedPeriod = CivitPeriod.Month;
[ObservableProperty]
private CivitSortMode sortMode = CivitSortMode.HighestRated;
[ObservableProperty]
private CivitModelType selectedModelType = CivitModelType.Checkpoint;
[ObservableProperty]
private int currentPageNumber;
[ObservableProperty]
private int totalPages;
[ObservableProperty]
private bool hasSearched;
[ObservableProperty]
private bool canGoToNextPage;
[ObservableProperty]
private bool canGoToPreviousPage;
[ObservableProperty]
private bool canGoToFirstPage;
[ObservableProperty]
private bool canGoToLastPage;
[ObservableProperty]
private bool isIndeterminate;
[ObservableProperty]
private bool noResultsFound;
[ObservableProperty]
private string noResultsText = string.Empty;
[ObservableProperty]
private string selectedBaseModelType = "All";
private List<CheckpointBrowserCardViewModel> allModelCards = new(); private List<CheckpointBrowserCardViewModel> allModelCards = new();
public IEnumerable<CivitPeriod> AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
public IEnumerable<CivitSortMode> AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public IEnumerable<CivitModelType> AllModelTypes => Enum.GetValues(typeof(CivitModelType)) public IEnumerable<CivitPeriod> AllCivitPeriods =>
.Cast<CivitModelType>() Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0) public IEnumerable<CivitSortMode> AllSortModes =>
.OrderBy(t => t.ToString()); Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public IEnumerable<CivitModelType> AllModelTypes =>
Enum.GetValues(typeof(CivitModelType))
.Cast<CivitModelType>()
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString());
public List<string> BaseModelOptions => new() {"All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0"}; public List<string> BaseModelOptions =>
new() { "All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0" };
public CheckpointBrowserViewModel( public CheckpointBrowserViewModel(
ICivitApi civitApi, ICivitApi civitApi,
IDownloadService downloadService, IDownloadService downloadService,
ISettingsManager settingsManager, ISettingsManager settingsManager,
ServiceManager<ViewModelBase> dialogFactory, ServiceManager<ViewModelBase> dialogFactory,
ILiteDbContext liteDbContext, ILiteDbContext liteDbContext,
INotificationService notificationService) INotificationService notificationService
)
{ {
this.civitApi = civitApi; this.civitApi = civitApi;
this.downloadService = downloadService; this.downloadService = downloadService;
@ -117,27 +162,38 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
public override void OnLoaded() public override void OnLoaded()
{ {
if (Design.IsDesignMode) return; if (Design.IsDesignMode)
return;
var searchOptions = settingsManager.Settings.ModelSearchOptions; var searchOptions = settingsManager.Settings.ModelSearchOptions;
// Fix SelectedModelType if someone had selected the obsolete "Model" option // Fix SelectedModelType if someone had selected the obsolete "Model" option
if (searchOptions is {SelectedModelType: CivitModelType.Model}) if (searchOptions is { SelectedModelType: CivitModelType.Model })
{ {
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, SortMode, CivitModelType.Checkpoint, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
CivitModelType.Checkpoint,
SelectedBaseModelType
)
);
searchOptions = settingsManager.Settings.ModelSearchOptions; searchOptions = settingsManager.Settings.ModelSearchOptions;
} }
SelectedPeriod = searchOptions?.SelectedPeriod ?? CivitPeriod.Month; SelectedPeriod = searchOptions?.SelectedPeriod ?? CivitPeriod.Month;
SortMode = searchOptions?.SortMode ?? CivitSortMode.HighestRated; SortMode = searchOptions?.SortMode ?? CivitSortMode.HighestRated;
SelectedModelType = searchOptions?.SelectedModelType ?? CivitModelType.Checkpoint; SelectedModelType = searchOptions?.SelectedModelType ?? CivitModelType.Checkpoint;
SelectedBaseModelType = searchOptions?.SelectedBaseModelType ?? "All"; SelectedBaseModelType = searchOptions?.SelectedBaseModelType ?? "All";
ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled; ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled;
settingsManager.RelayPropertyFor(this, model => model.ShowNsfw, settingsManager.RelayPropertyFor(
settings => settings.ModelBrowserNsfwEnabled); this,
model => model.ShowNsfw,
settings => settings.ModelBrowserNsfwEnabled
);
} }
/// <summary> /// <summary>
@ -145,7 +201,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
/// </summary> /// </summary>
private bool FilterModelCardsPredicate(object? item) private bool FilterModelCardsPredicate(object? item)
{ {
if (item is not CheckpointBrowserCardViewModel card) return false; if (item is not CheckpointBrowserCardViewModel card)
return false;
return !card.CivitModel.Nsfw || ShowNsfw; return !card.CivitModel.Nsfw || ShowNsfw;
} }
@ -162,38 +219,52 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
var models = modelsResponse.Items; var models = modelsResponse.Items;
if (models is null) if (models is null)
{ {
Logger.Debug("CivitAI Query {Text} returned no results (in {Elapsed:F1} s)", Logger.Debug(
queryText, timer.Elapsed.TotalSeconds); "CivitAI Query {Text} returned no results (in {Elapsed:F1} s)",
queryText,
timer.Elapsed.TotalSeconds
);
return; return;
} }
Logger.Debug("CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)", Logger.Debug(
queryText, models.Count, timer.Elapsed.TotalSeconds); "CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)",
queryText,
models.Count,
timer.Elapsed.TotalSeconds
);
var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList(); var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList();
if (unknown.Any()) if (unknown.Any())
{ {
var names = unknown.Select(m => m.Name).ToList(); var names = unknown.Select(m => m.Name).ToList();
Logger.Warn("Excluded {Unknown} unknown model types: {Models}", unknown.Count, Logger.Warn(
names); "Excluded {Unknown} unknown model types: {Models}",
unknown.Count,
names
);
} }
// Filter out unknown model types and archived/taken-down models // Filter out unknown model types and archived/taken-down models
models = models.Where(m => m.Type.ConvertTo<SharedFolderType>() > 0) models = models
.Where(m => m.Mode == null).ToList(); .Where(m => m.Type.ConvertTo<SharedFolderType>() > 0)
.Where(m => m.Mode == null)
.ToList();
// Database update calls will invoke `OnModelsUpdated` // Database update calls will invoke `OnModelsUpdated`
// Add to database // Add to database
await liteDbContext.UpsertCivitModelAsync(models); await liteDbContext.UpsertCivitModelAsync(models);
// Add as cache entry // Add as cache entry
var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(new() var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(
{ new()
Id = ObjectHash.GetMd5Guid(request), {
InsertedAt = DateTimeOffset.UtcNow, Id = ObjectHash.GetMd5Guid(request),
Request = request, InsertedAt = DateTimeOffset.UtcNow,
Items = models, Request = request,
Metadata = modelsResponse.Metadata Items = models,
}); Metadata = modelsResponse.Metadata
}
);
if (cacheNew) if (cacheNew)
{ {
@ -207,26 +278,42 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
notificationService.Show(new Notification("Request to CivitAI timed out", notificationService.Show(
"Please try again in a few minutes")); new Notification(
"Request to CivitAI timed out",
"Please try again in a few minutes"
)
);
Logger.Warn($"CivitAI query timed out ({request})"); Logger.Warn($"CivitAI query timed out ({request})");
} }
catch (HttpRequestException e) catch (HttpRequestException e)
{ {
notificationService.Show(new Notification("CivitAI can't be reached right now", notificationService.Show(
"Please try again in a few minutes")); new Notification(
"CivitAI can't be reached right now",
"Please try again in a few minutes"
)
);
Logger.Warn(e, $"CivitAI query HttpRequestException ({request})"); Logger.Warn(e, $"CivitAI query HttpRequestException ({request})");
} }
catch (ApiException e) catch (ApiException e)
{ {
notificationService.Show(new Notification("CivitAI can't be reached right now", notificationService.Show(
"Please try again in a few minutes")); new Notification(
"CivitAI can't be reached right now",
"Please try again in a few minutes"
)
);
Logger.Warn(e, $"CivitAI query ApiException ({request})"); Logger.Warn(e, $"CivitAI query ApiException ({request})");
} }
catch (Exception e) catch (Exception e)
{ {
notificationService.Show(new Notification("CivitAI can't be reached right now", notificationService.Show(
$"Unknown exception during CivitAI query: {e.GetType().Name}")); new Notification(
"CivitAI can't be reached right now",
$"Unknown exception during CivitAI query: {e.GetType().Name}"
)
);
Logger.Error(e, $"CivitAI query unknown exception ({request})"); Logger.Error(e, $"CivitAI query unknown exception ({request})");
} }
finally finally
@ -235,11 +322,11 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
UpdateResultsText(); UpdateResultsText();
} }
} }
/// <summary> /// <summary>
/// Updates model cards using api response object. /// Updates model cards using api response object.
/// </summary> /// </summary>
private void UpdateModelCards(IEnumerable<CivitModel>? models, CivitMetadata? metadata) private void UpdateModelCards(IEnumerable<CivitModel>? models, CivitMetadata? metadata)
{ {
if (models is null) if (models is null)
{ {
@ -266,26 +353,29 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
vm.CivitModel = model; vm.CivitModel = model;
vm.OnDownloadStart = viewModel => vm.OnDownloadStart = viewModel =>
{ {
if (cache.Get(viewModel.CivitModel.Id) != null) return; if (cache.Get(viewModel.CivitModel.Id) != null)
return;
cache.Add(viewModel.CivitModel.Id, viewModel); cache.Add(viewModel.CivitModel.Id, viewModel);
}; };
return vm; return vm;
}); });
return newCard; return newCard;
}).ToList(); })
.ToList();
allModelCards = updateCards; allModelCards = updateCards;
var filteredCards = updateCards.Where(FilterModelCardsPredicate); var filteredCards = updateCards.Where(FilterModelCardsPredicate);
if (SortMode == CivitSortMode.Installed) if (SortMode == CivitSortMode.Installed)
{ {
filteredCards = filteredCards = filteredCards.OrderByDescending(
filteredCards.OrderByDescending(x => x.UpdateCardText == "Update Available"); x => x.UpdateCardText == "Update Available"
);
} }
ModelCards =new ObservableCollection<CheckpointBrowserCardViewModel>(filteredCards); ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(filteredCards);
} }
TotalPages = metadata?.TotalPages ?? 1; TotalPages = metadata?.TotalPages ?? 1;
CanGoToFirstPage = CurrentPageNumber != 1; CanGoToFirstPage = CurrentPageNumber != 1;
@ -304,14 +394,14 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private async Task SearchModels() private async Task SearchModels()
{ {
var timer = Stopwatch.StartNew(); var timer = Stopwatch.StartNew();
if (SearchQuery != previousSearchQuery) if (SearchQuery != previousSearchQuery)
{ {
// Reset page number // Reset page number
CurrentPageNumber = 1; CurrentPageNumber = 1;
previousSearchQuery = SearchQuery; previousSearchQuery = SearchQuery;
} }
// Build request // Build request
var modelRequest = new CivitModelsRequest var modelRequest = new CivitModelsRequest
{ {
@ -334,10 +424,10 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
{ {
modelRequest.Query = SearchQuery; modelRequest.Query = SearchQuery;
} }
if (SelectedModelType != CivitModelType.All) if (SelectedModelType != CivitModelType.All)
{ {
modelRequest.Types = new[] {SelectedModelType}; modelRequest.Types = new[] { SelectedModelType };
} }
if (SelectedBaseModelType != "All") if (SelectedBaseModelType != "All")
@ -347,9 +437,9 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
if (SortMode == CivitSortMode.Installed) if (SortMode == CivitSortMode.Installed)
{ {
var connectedModels = var connectedModels = CheckpointFile
CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory) .GetAllCheckpointFiles(settingsManager.ModelsDirectory)
.Where(c => c.IsConnectedModel); .Where(c => c.IsConnectedModel);
if (SelectedModelType != CivitModelType.All) if (SelectedModelType != CivitModelType.All)
{ {
@ -358,24 +448,34 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
modelRequest = new CivitModelsRequest modelRequest = new CivitModelsRequest
{ {
CommaSeparatedModelIds = string.Join(",", CommaSeparatedModelIds = string.Join(
connectedModels.Select(c => c.ConnectedModel!.ModelId).GroupBy(m => m) ",",
.Select(g => g.First())), connectedModels
Types = SelectedModelType == CivitModelType.All ? null : new[] {SelectedModelType} .Select(c => c.ConnectedModel!.ModelId)
.GroupBy(m => m)
.Select(g => g.First())
),
Types =
SelectedModelType == CivitModelType.All ? null : new[] { SelectedModelType },
Page = CurrentPageNumber,
}; };
} }
// See if query is cached // See if query is cached
var cachedQuery = await liteDbContext.CivitModelQueryCache var cachedQuery = await liteDbContext.CivitModelQueryCache
.IncludeAll() .IncludeAll()
.FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest)); .FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest));
// If cached, update model cards // If cached, update model cards
if (cachedQuery is not null) if (cachedQuery is not null)
{ {
var elapsed = timer.Elapsed; var elapsed = timer.Elapsed;
Logger.Debug("Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)", Logger.Debug(
SearchQuery, modelRequest.GetHashCode(), elapsed.TotalSeconds); "Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)",
SearchQuery,
modelRequest.GetHashCode(),
elapsed.TotalSeconds
);
UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata); UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata);
// Start remote query (background mode) // Start remote query (background mode)
@ -386,7 +486,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
CivitModelQuery(modelRequest).SafeFireAndForget(); CivitModelQuery(modelRequest).SafeFireAndForget();
Logger.Debug( Logger.Debug(
"Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query", "Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query",
timeSinceCache.Value.TotalSeconds); timeSinceCache.Value.TotalSeconds
);
} }
} }
else else
@ -395,7 +496,7 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
ShowMainLoadingSpinner = true; ShowMainLoadingSpinner = true;
await CivitModelQuery(modelRequest); await CivitModelQuery(modelRequest);
} }
UpdateResultsText(); UpdateResultsText();
} }
@ -406,17 +507,17 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
public void PreviousPage() public void PreviousPage()
{ {
if (CurrentPageNumber == 1) if (CurrentPageNumber == 1)
return; return;
CurrentPageNumber--; CurrentPageNumber--;
} }
public void NextPage() public void NextPage()
{ {
if (CurrentPageNumber == TotalPages) if (CurrentPageNumber == TotalPages)
return; return;
CurrentPageNumber++; CurrentPageNumber++;
} }
@ -429,46 +530,75 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
{ {
settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value); settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value);
// ModelCardsView?.Refresh(); // ModelCardsView?.Refresh();
var updateCards = allModelCards var updateCards = allModelCards.Where(FilterModelCardsPredicate);
.Where(FilterModelCardsPredicate);
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(updateCards); ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(updateCards);
if (!HasSearched) return; if (!HasSearched)
return;
UpdateResultsText(); UpdateResultsText();
} }
partial void OnSelectedPeriodChanged(CivitPeriod value) partial void OnSelectedPeriodChanged(CivitPeriod value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
value, SortMode, SelectedModelType, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
value,
SortMode,
SelectedModelType,
SelectedBaseModelType
)
);
} }
partial void OnSortModeChanged(CivitSortMode value) partial void OnSortModeChanged(CivitSortMode value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, value, SelectedModelType, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
value,
SelectedModelType,
SelectedBaseModelType
)
);
} }
partial void OnSelectedModelTypeChanged(CivitModelType value) partial void OnSelectedModelTypeChanged(CivitModelType value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, SortMode, value, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
value,
SelectedBaseModelType
)
);
} }
partial void OnSelectedBaseModelTypeChanged(string value) partial void OnSelectedBaseModelTypeChanged(string value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, SortMode, SelectedModelType, value)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
SelectedModelType,
value
)
);
} }
private async Task TrySearchAgain(bool shouldUpdatePageNumber = true) private async Task TrySearchAgain(bool shouldUpdatePageNumber = true)
{ {
if (!HasSearched) return; if (!HasSearched)
return;
ModelCards?.Clear(); ModelCards?.Clear();
if (shouldUpdatePageNumber) if (shouldUpdatePageNumber)
@ -483,11 +613,13 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private void UpdateResultsText() private void UpdateResultsText()
{ {
NoResultsFound = ModelCards?.Count <= 0; NoResultsFound = ModelCards?.Count <= 0;
NoResultsText = allModelCards.Count > 0 NoResultsText =
? $"{allModelCards.Count} results hidden by filters" allModelCards.Count > 0
: "No results found"; ? $"{allModelCards.Count} results hidden by filters"
: "No results found";
} }
public override string Title => "Model Browser"; public override string Title => "Model Browser";
public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true }; public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true };
} }

Loading…
Cancel
Save