Browse Source

Merge branch 'main' into inference

pull/165/head
Ionite 1 year ago committed by GitHub
parent
commit
4c68fd5bcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 296
      StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs

296
StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs

@ -51,40 +51,84 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private readonly ILiteDbContext liteDbContext; private readonly ILiteDbContext liteDbContext;
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private const int MaxModelsPerPage = 14; private const int MaxModelsPerPage = 14;
private LRUCache<int /* model id */, CheckpointBrowserCardViewModel> cache = new(50); private LRUCache<
int /* model id */
[ObservableProperty] private ObservableCollection<CheckpointBrowserCardViewModel>? modelCards; ,
[ObservableProperty] private DataGridCollectionView? modelCardsView; CheckpointBrowserCardViewModel
> cache = new(50);
[ObservableProperty] private string searchQuery = string.Empty;
[ObservableProperty] private bool showNsfw; [ObservableProperty]
[ObservableProperty] private bool showMainLoadingSpinner; private ObservableCollection<CheckpointBrowserCardViewModel>? modelCards;
[ObservableProperty] private CivitPeriod selectedPeriod = CivitPeriod.Month;
[ObservableProperty] private CivitSortMode sortMode = CivitSortMode.HighestRated; [ObservableProperty]
[ObservableProperty] private CivitModelType selectedModelType = CivitModelType.Checkpoint; private DataGridCollectionView? modelCardsView;
[ObservableProperty] private int currentPageNumber;
[ObservableProperty] private int totalPages; [ObservableProperty]
[ObservableProperty] private bool hasSearched; private string searchQuery = string.Empty;
[ObservableProperty] private bool canGoToNextPage;
[ObservableProperty] private bool canGoToPreviousPage; [ObservableProperty]
[ObservableProperty] private bool canGoToFirstPage; private bool showNsfw;
[ObservableProperty] private bool canGoToLastPage;
[ObservableProperty] private bool isIndeterminate; [ObservableProperty]
[ObservableProperty] private bool noResultsFound; private bool showMainLoadingSpinner;
[ObservableProperty] private string noResultsText = string.Empty;
[ObservableProperty] private string selectedBaseModelType = "All"; [ObservableProperty]
private CivitPeriod selectedPeriod = CivitPeriod.Month;
[ObservableProperty]
private CivitSortMode sortMode = CivitSortMode.HighestRated;
[ObservableProperty]
private CivitModelType selectedModelType = CivitModelType.Checkpoint;
[ObservableProperty]
private int currentPageNumber;
[ObservableProperty]
private int totalPages;
[ObservableProperty]
private bool hasSearched;
[ObservableProperty]
private bool canGoToNextPage;
[ObservableProperty]
private bool canGoToPreviousPage;
[ObservableProperty]
private bool canGoToFirstPage;
[ObservableProperty]
private bool canGoToLastPage;
[ObservableProperty]
private bool isIndeterminate;
[ObservableProperty]
private bool noResultsFound;
[ObservableProperty]
private string noResultsText = string.Empty;
[ObservableProperty]
private string selectedBaseModelType = "All";
private List<CheckpointBrowserCardViewModel> allModelCards = new(); private List<CheckpointBrowserCardViewModel> allModelCards = new();
public IEnumerable<CivitPeriod> AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>(); public IEnumerable<CivitPeriod> AllCivitPeriods =>
public IEnumerable<CivitSortMode> AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>(); Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
public IEnumerable<CivitSortMode> AllSortModes =>
Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public IEnumerable<CivitModelType> AllModelTypes => Enum.GetValues(typeof(CivitModelType)) public IEnumerable<CivitModelType> AllModelTypes =>
Enum.GetValues(typeof(CivitModelType))
.Cast<CivitModelType>() .Cast<CivitModelType>()
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0) .Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString()); .OrderBy(t => t.ToString());
public List<string> BaseModelOptions => new() {"All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0"}; public List<string> BaseModelOptions =>
new() { "All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0" };
public CheckpointBrowserViewModel( public CheckpointBrowserViewModel(
ICivitApi civitApi, ICivitApi civitApi,
@ -92,7 +136,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
ISettingsManager settingsManager, ISettingsManager settingsManager,
ServiceManager<ViewModelBase> dialogFactory, ServiceManager<ViewModelBase> dialogFactory,
ILiteDbContext liteDbContext, ILiteDbContext liteDbContext,
INotificationService notificationService) INotificationService notificationService
)
{ {
this.civitApi = civitApi; this.civitApi = civitApi;
this.downloadService = downloadService; this.downloadService = downloadService;
@ -117,15 +162,23 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
public override void OnLoaded() public override void OnLoaded()
{ {
if (Design.IsDesignMode) return; if (Design.IsDesignMode)
return;
var searchOptions = settingsManager.Settings.ModelSearchOptions; var searchOptions = settingsManager.Settings.ModelSearchOptions;
// Fix SelectedModelType if someone had selected the obsolete "Model" option // Fix SelectedModelType if someone had selected the obsolete "Model" option
if (searchOptions is {SelectedModelType: CivitModelType.Model}) if (searchOptions is { SelectedModelType: CivitModelType.Model })
{ {
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, SortMode, CivitModelType.Checkpoint, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
CivitModelType.Checkpoint,
SelectedBaseModelType
)
);
searchOptions = settingsManager.Settings.ModelSearchOptions; searchOptions = settingsManager.Settings.ModelSearchOptions;
} }
@ -136,8 +189,11 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled; ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled;
settingsManager.RelayPropertyFor(this, model => model.ShowNsfw, settingsManager.RelayPropertyFor(
settings => settings.ModelBrowserNsfwEnabled); this,
model => model.ShowNsfw,
settings => settings.ModelBrowserNsfwEnabled
);
} }
/// <summary> /// <summary>
@ -145,7 +201,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
/// </summary> /// </summary>
private bool FilterModelCardsPredicate(object? item) private bool FilterModelCardsPredicate(object? item)
{ {
if (item is not CheckpointBrowserCardViewModel card) return false; if (item is not CheckpointBrowserCardViewModel card)
return false;
return !card.CivitModel.Nsfw || ShowNsfw; return !card.CivitModel.Nsfw || ShowNsfw;
} }
@ -162,38 +219,52 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
var models = modelsResponse.Items; var models = modelsResponse.Items;
if (models is null) if (models is null)
{ {
Logger.Debug("CivitAI Query {Text} returned no results (in {Elapsed:F1} s)", Logger.Debug(
queryText, timer.Elapsed.TotalSeconds); "CivitAI Query {Text} returned no results (in {Elapsed:F1} s)",
queryText,
timer.Elapsed.TotalSeconds
);
return; return;
} }
Logger.Debug("CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)", Logger.Debug(
queryText, models.Count, timer.Elapsed.TotalSeconds); "CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)",
queryText,
models.Count,
timer.Elapsed.TotalSeconds
);
var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList(); var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList();
if (unknown.Any()) if (unknown.Any())
{ {
var names = unknown.Select(m => m.Name).ToList(); var names = unknown.Select(m => m.Name).ToList();
Logger.Warn("Excluded {Unknown} unknown model types: {Models}", unknown.Count, Logger.Warn(
names); "Excluded {Unknown} unknown model types: {Models}",
unknown.Count,
names
);
} }
// Filter out unknown model types and archived/taken-down models // Filter out unknown model types and archived/taken-down models
models = models.Where(m => m.Type.ConvertTo<SharedFolderType>() > 0) models = models
.Where(m => m.Mode == null).ToList(); .Where(m => m.Type.ConvertTo<SharedFolderType>() > 0)
.Where(m => m.Mode == null)
.ToList();
// Database update calls will invoke `OnModelsUpdated` // Database update calls will invoke `OnModelsUpdated`
// Add to database // Add to database
await liteDbContext.UpsertCivitModelAsync(models); await liteDbContext.UpsertCivitModelAsync(models);
// Add as cache entry // Add as cache entry
var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(new() var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(
new()
{ {
Id = ObjectHash.GetMd5Guid(request), Id = ObjectHash.GetMd5Guid(request),
InsertedAt = DateTimeOffset.UtcNow, InsertedAt = DateTimeOffset.UtcNow,
Request = request, Request = request,
Items = models, Items = models,
Metadata = modelsResponse.Metadata Metadata = modelsResponse.Metadata
}); }
);
if (cacheNew) if (cacheNew)
{ {
@ -207,26 +278,42 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
notificationService.Show(new Notification("Request to CivitAI timed out", notificationService.Show(
"Please try again in a few minutes")); new Notification(
"Request to CivitAI timed out",
"Please try again in a few minutes"
)
);
Logger.Warn($"CivitAI query timed out ({request})"); Logger.Warn($"CivitAI query timed out ({request})");
} }
catch (HttpRequestException e) catch (HttpRequestException e)
{ {
notificationService.Show(new Notification("CivitAI can't be reached right now", notificationService.Show(
"Please try again in a few minutes")); new Notification(
"CivitAI can't be reached right now",
"Please try again in a few minutes"
)
);
Logger.Warn(e, $"CivitAI query HttpRequestException ({request})"); Logger.Warn(e, $"CivitAI query HttpRequestException ({request})");
} }
catch (ApiException e) catch (ApiException e)
{ {
notificationService.Show(new Notification("CivitAI can't be reached right now", notificationService.Show(
"Please try again in a few minutes")); new Notification(
"CivitAI can't be reached right now",
"Please try again in a few minutes"
)
);
Logger.Warn(e, $"CivitAI query ApiException ({request})"); Logger.Warn(e, $"CivitAI query ApiException ({request})");
} }
catch (Exception e) catch (Exception e)
{ {
notificationService.Show(new Notification("CivitAI can't be reached right now", notificationService.Show(
$"Unknown exception during CivitAI query: {e.GetType().Name}")); new Notification(
"CivitAI can't be reached right now",
$"Unknown exception during CivitAI query: {e.GetType().Name}"
)
);
Logger.Error(e, $"CivitAI query unknown exception ({request})"); Logger.Error(e, $"CivitAI query unknown exception ({request})");
} }
finally finally
@ -266,7 +353,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
vm.CivitModel = model; vm.CivitModel = model;
vm.OnDownloadStart = viewModel => vm.OnDownloadStart = viewModel =>
{ {
if (cache.Get(viewModel.CivitModel.Id) != null) return; if (cache.Get(viewModel.CivitModel.Id) != null)
return;
cache.Add(viewModel.CivitModel.Id, viewModel); cache.Add(viewModel.CivitModel.Id, viewModel);
}; };
@ -274,18 +362,20 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
}); });
return newCard; return newCard;
}).ToList(); })
.ToList();
allModelCards = updateCards; allModelCards = updateCards;
var filteredCards = updateCards.Where(FilterModelCardsPredicate); var filteredCards = updateCards.Where(FilterModelCardsPredicate);
if (SortMode == CivitSortMode.Installed) if (SortMode == CivitSortMode.Installed)
{ {
filteredCards = filteredCards = filteredCards.OrderByDescending(
filteredCards.OrderByDescending(x => x.UpdateCardText == "Update Available"); x => x.UpdateCardText == "Update Available"
);
} }
ModelCards =new ObservableCollection<CheckpointBrowserCardViewModel>(filteredCards); ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(filteredCards);
} }
TotalPages = metadata?.TotalPages ?? 1; TotalPages = metadata?.TotalPages ?? 1;
CanGoToFirstPage = CurrentPageNumber != 1; CanGoToFirstPage = CurrentPageNumber != 1;
@ -337,7 +427,7 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
if (SelectedModelType != CivitModelType.All) if (SelectedModelType != CivitModelType.All)
{ {
modelRequest.Types = new[] {SelectedModelType}; modelRequest.Types = new[] { SelectedModelType };
} }
if (SelectedBaseModelType != "All") if (SelectedBaseModelType != "All")
@ -347,8 +437,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
if (SortMode == CivitSortMode.Installed) if (SortMode == CivitSortMode.Installed)
{ {
var connectedModels = var connectedModels = CheckpointFile
CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory) .GetAllCheckpointFiles(settingsManager.ModelsDirectory)
.Where(c => c.IsConnectedModel); .Where(c => c.IsConnectedModel);
if (SelectedModelType != CivitModelType.All) if (SelectedModelType != CivitModelType.All)
@ -358,10 +448,16 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
modelRequest = new CivitModelsRequest modelRequest = new CivitModelsRequest
{ {
CommaSeparatedModelIds = string.Join(",", CommaSeparatedModelIds = string.Join(
connectedModels.Select(c => c.ConnectedModel!.ModelId).GroupBy(m => m) ",",
.Select(g => g.First())), connectedModels
Types = SelectedModelType == CivitModelType.All ? null : new[] {SelectedModelType} .Select(c => c.ConnectedModel!.ModelId)
.GroupBy(m => m)
.Select(g => g.First())
),
Types =
SelectedModelType == CivitModelType.All ? null : new[] { SelectedModelType },
Page = CurrentPageNumber,
}; };
} }
@ -374,8 +470,12 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
if (cachedQuery is not null) if (cachedQuery is not null)
{ {
var elapsed = timer.Elapsed; var elapsed = timer.Elapsed;
Logger.Debug("Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)", Logger.Debug(
SearchQuery, modelRequest.GetHashCode(), elapsed.TotalSeconds); "Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)",
SearchQuery,
modelRequest.GetHashCode(),
elapsed.TotalSeconds
);
UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata); UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata);
// Start remote query (background mode) // Start remote query (background mode)
@ -386,7 +486,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
CivitModelQuery(modelRequest).SafeFireAndForget(); CivitModelQuery(modelRequest).SafeFireAndForget();
Logger.Debug( Logger.Debug(
"Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query", "Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query",
timeSinceCache.Value.TotalSeconds); timeSinceCache.Value.TotalSeconds
);
} }
} }
else else
@ -429,11 +530,11 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
{ {
settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value); settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value);
// ModelCardsView?.Refresh(); // ModelCardsView?.Refresh();
var updateCards = allModelCards var updateCards = allModelCards.Where(FilterModelCardsPredicate);
.Where(FilterModelCardsPredicate);
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(updateCards); ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(updateCards);
if (!HasSearched) return; if (!HasSearched)
return;
UpdateResultsText(); UpdateResultsText();
} }
@ -441,34 +542,63 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
partial void OnSelectedPeriodChanged(CivitPeriod value) partial void OnSelectedPeriodChanged(CivitPeriod value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
value, SortMode, SelectedModelType, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
value,
SortMode,
SelectedModelType,
SelectedBaseModelType
)
);
} }
partial void OnSortModeChanged(CivitSortMode value) partial void OnSortModeChanged(CivitSortMode value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, value, SelectedModelType, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
value,
SelectedModelType,
SelectedBaseModelType
)
);
} }
partial void OnSelectedModelTypeChanged(CivitModelType value) partial void OnSelectedModelTypeChanged(CivitModelType value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, SortMode, value, SelectedBaseModelType)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
value,
SelectedBaseModelType
)
);
} }
partial void OnSelectedBaseModelTypeChanged(string value) partial void OnSelectedBaseModelTypeChanged(string value)
{ {
TrySearchAgain().SafeFireAndForget(); TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions( settingsManager.Transaction(
SelectedPeriod, SortMode, SelectedModelType, value)); s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
SelectedModelType,
value
)
);
} }
private async Task TrySearchAgain(bool shouldUpdatePageNumber = true) private async Task TrySearchAgain(bool shouldUpdatePageNumber = true)
{ {
if (!HasSearched) return; if (!HasSearched)
return;
ModelCards?.Clear(); ModelCards?.Clear();
if (shouldUpdatePageNumber) if (shouldUpdatePageNumber)
@ -483,11 +613,13 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private void UpdateResultsText() private void UpdateResultsText()
{ {
NoResultsFound = ModelCards?.Count <= 0; NoResultsFound = ModelCards?.Count <= 0;
NoResultsText = allModelCards.Count > 0 NoResultsText =
allModelCards.Count > 0
? $"{allModelCards.Count} results hidden by filters" ? $"{allModelCards.Count} results hidden by filters"
: "No results found"; : "No results found";
} }
public override string Title => "Model Browser"; public override string Title => "Model Browser";
public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true }; public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true };
} }

Loading…
Cancel
Save