Browse Source

Merge branch 'main' into inference

pull/165/head
Ionite 1 year ago committed by GitHub
parent
commit
4c68fd5bcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 372
      StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs

372
StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs

@ -51,48 +51,93 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private readonly ILiteDbContext liteDbContext;
private readonly INotificationService notificationService;
private const int MaxModelsPerPage = 14;
private LRUCache<int /* model id */, CheckpointBrowserCardViewModel> cache = new(50);
[ObservableProperty] private ObservableCollection<CheckpointBrowserCardViewModel>? modelCards;
[ObservableProperty] private DataGridCollectionView? modelCardsView;
[ObservableProperty] private string searchQuery = string.Empty;
[ObservableProperty] private bool showNsfw;
[ObservableProperty] private bool showMainLoadingSpinner;
[ObservableProperty] private CivitPeriod selectedPeriod = CivitPeriod.Month;
[ObservableProperty] private CivitSortMode sortMode = CivitSortMode.HighestRated;
[ObservableProperty] private CivitModelType selectedModelType = CivitModelType.Checkpoint;
[ObservableProperty] private int currentPageNumber;
[ObservableProperty] private int totalPages;
[ObservableProperty] private bool hasSearched;
[ObservableProperty] private bool canGoToNextPage;
[ObservableProperty] private bool canGoToPreviousPage;
[ObservableProperty] private bool canGoToFirstPage;
[ObservableProperty] private bool canGoToLastPage;
[ObservableProperty] private bool isIndeterminate;
[ObservableProperty] private bool noResultsFound;
[ObservableProperty] private string noResultsText = string.Empty;
[ObservableProperty] private string selectedBaseModelType = "All";
private LRUCache<
int /* model id */
,
CheckpointBrowserCardViewModel
> cache = new(50);
[ObservableProperty]
private ObservableCollection<CheckpointBrowserCardViewModel>? modelCards;
[ObservableProperty]
private DataGridCollectionView? modelCardsView;
[ObservableProperty]
private string searchQuery = string.Empty;
[ObservableProperty]
private bool showNsfw;
[ObservableProperty]
private bool showMainLoadingSpinner;
[ObservableProperty]
private CivitPeriod selectedPeriod = CivitPeriod.Month;
[ObservableProperty]
private CivitSortMode sortMode = CivitSortMode.HighestRated;
[ObservableProperty]
private CivitModelType selectedModelType = CivitModelType.Checkpoint;
[ObservableProperty]
private int currentPageNumber;
[ObservableProperty]
private int totalPages;
[ObservableProperty]
private bool hasSearched;
[ObservableProperty]
private bool canGoToNextPage;
[ObservableProperty]
private bool canGoToPreviousPage;
[ObservableProperty]
private bool canGoToFirstPage;
[ObservableProperty]
private bool canGoToLastPage;
[ObservableProperty]
private bool isIndeterminate;
[ObservableProperty]
private bool noResultsFound;
[ObservableProperty]
private string noResultsText = string.Empty;
[ObservableProperty]
private string selectedBaseModelType = "All";
private List<CheckpointBrowserCardViewModel> allModelCards = new();
public IEnumerable<CivitPeriod> AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
public IEnumerable<CivitSortMode> AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public IEnumerable<CivitModelType> AllModelTypes => Enum.GetValues(typeof(CivitModelType))
.Cast<CivitModelType>()
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString());
public IEnumerable<CivitPeriod> AllCivitPeriods =>
Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
public IEnumerable<CivitSortMode> AllSortModes =>
Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public IEnumerable<CivitModelType> AllModelTypes =>
Enum.GetValues(typeof(CivitModelType))
.Cast<CivitModelType>()
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString());
public List<string> BaseModelOptions => new() {"All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0"};
public List<string> BaseModelOptions =>
new() { "All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0" };
public CheckpointBrowserViewModel(
ICivitApi civitApi,
IDownloadService downloadService,
ICivitApi civitApi,
IDownloadService downloadService,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> dialogFactory,
ILiteDbContext liteDbContext,
INotificationService notificationService)
INotificationService notificationService
)
{
this.civitApi = civitApi;
this.downloadService = downloadService;
@ -117,27 +162,38 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
public override void OnLoaded()
{
if (Design.IsDesignMode) return;
if (Design.IsDesignMode)
return;
var searchOptions = settingsManager.Settings.ModelSearchOptions;
// Fix SelectedModelType if someone had selected the obsolete "Model" option
if (searchOptions is {SelectedModelType: CivitModelType.Model})
if (searchOptions is { SelectedModelType: CivitModelType.Model })
{
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod, SortMode, CivitModelType.Checkpoint, SelectedBaseModelType));
settingsManager.Transaction(
s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
CivitModelType.Checkpoint,
SelectedBaseModelType
)
);
searchOptions = settingsManager.Settings.ModelSearchOptions;
}
SelectedPeriod = searchOptions?.SelectedPeriod ?? CivitPeriod.Month;
SortMode = searchOptions?.SortMode ?? CivitSortMode.HighestRated;
SelectedModelType = searchOptions?.SelectedModelType ?? CivitModelType.Checkpoint;
SelectedBaseModelType = searchOptions?.SelectedBaseModelType ?? "All";
ShowNsfw = settingsManager.Settings.ModelBrowserNsfwEnabled;
settingsManager.RelayPropertyFor(this, model => model.ShowNsfw,
settings => settings.ModelBrowserNsfwEnabled);
settingsManager.RelayPropertyFor(
this,
model => model.ShowNsfw,
settings => settings.ModelBrowserNsfwEnabled
);
}
/// <summary>
@ -145,7 +201,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
/// </summary>
private bool FilterModelCardsPredicate(object? item)
{
if (item is not CheckpointBrowserCardViewModel card) return false;
if (item is not CheckpointBrowserCardViewModel card)
return false;
return !card.CivitModel.Nsfw || ShowNsfw;
}
@ -162,38 +219,52 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
var models = modelsResponse.Items;
if (models is null)
{
Logger.Debug("CivitAI Query {Text} returned no results (in {Elapsed:F1} s)",
queryText, timer.Elapsed.TotalSeconds);
Logger.Debug(
"CivitAI Query {Text} returned no results (in {Elapsed:F1} s)",
queryText,
timer.Elapsed.TotalSeconds
);
return;
}
Logger.Debug("CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)",
queryText, models.Count, timer.Elapsed.TotalSeconds);
Logger.Debug(
"CivitAI Query {Text} returned {Results} results (in {Elapsed:F1} s)",
queryText,
models.Count,
timer.Elapsed.TotalSeconds
);
var unknown = models.Where(m => m.Type == CivitModelType.Unknown).ToList();
if (unknown.Any())
{
var names = unknown.Select(m => m.Name).ToList();
Logger.Warn("Excluded {Unknown} unknown model types: {Models}", unknown.Count,
names);
Logger.Warn(
"Excluded {Unknown} unknown model types: {Models}",
unknown.Count,
names
);
}
// Filter out unknown model types and archived/taken-down models
models = models.Where(m => m.Type.ConvertTo<SharedFolderType>() > 0)
.Where(m => m.Mode == null).ToList();
models = models
.Where(m => m.Type.ConvertTo<SharedFolderType>() > 0)
.Where(m => m.Mode == null)
.ToList();
// Database update calls will invoke `OnModelsUpdated`
// Add to database
await liteDbContext.UpsertCivitModelAsync(models);
// Add as cache entry
var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(new()
{
Id = ObjectHash.GetMd5Guid(request),
InsertedAt = DateTimeOffset.UtcNow,
Request = request,
Items = models,
Metadata = modelsResponse.Metadata
});
var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(
new()
{
Id = ObjectHash.GetMd5Guid(request),
InsertedAt = DateTimeOffset.UtcNow,
Request = request,
Items = models,
Metadata = modelsResponse.Metadata
}
);
if (cacheNew)
{
@ -207,26 +278,42 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
}
catch (OperationCanceledException)
{
notificationService.Show(new Notification("Request to CivitAI timed out",
"Please try again in a few minutes"));
notificationService.Show(
new Notification(
"Request to CivitAI timed out",
"Please try again in a few minutes"
)
);
Logger.Warn($"CivitAI query timed out ({request})");
}
catch (HttpRequestException e)
{
notificationService.Show(new Notification("CivitAI can't be reached right now",
"Please try again in a few minutes"));
notificationService.Show(
new Notification(
"CivitAI can't be reached right now",
"Please try again in a few minutes"
)
);
Logger.Warn(e, $"CivitAI query HttpRequestException ({request})");
}
catch (ApiException e)
{
notificationService.Show(new Notification("CivitAI can't be reached right now",
"Please try again in a few minutes"));
notificationService.Show(
new Notification(
"CivitAI can't be reached right now",
"Please try again in a few minutes"
)
);
Logger.Warn(e, $"CivitAI query ApiException ({request})");
}
catch (Exception e)
{
notificationService.Show(new Notification("CivitAI can't be reached right now",
$"Unknown exception during CivitAI query: {e.GetType().Name}"));
notificationService.Show(
new Notification(
"CivitAI can't be reached right now",
$"Unknown exception during CivitAI query: {e.GetType().Name}"
)
);
Logger.Error(e, $"CivitAI query unknown exception ({request})");
}
finally
@ -235,11 +322,11 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
UpdateResultsText();
}
}
/// <summary>
/// Updates model cards using api response object.
/// </summary>
private void UpdateModelCards(IEnumerable<CivitModel>? models, CivitMetadata? metadata)
private void UpdateModelCards(IEnumerable<CivitModel>? models, CivitMetadata? metadata)
{
if (models is null)
{
@ -266,26 +353,29 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
vm.CivitModel = model;
vm.OnDownloadStart = viewModel =>
{
if (cache.Get(viewModel.CivitModel.Id) != null) return;
if (cache.Get(viewModel.CivitModel.Id) != null)
return;
cache.Add(viewModel.CivitModel.Id, viewModel);
};
return vm;
});
return newCard;
}).ToList();
})
.ToList();
allModelCards = updateCards;
var filteredCards = updateCards.Where(FilterModelCardsPredicate);
if (SortMode == CivitSortMode.Installed)
{
filteredCards =
filteredCards.OrderByDescending(x => x.UpdateCardText == "Update Available");
filteredCards = filteredCards.OrderByDescending(
x => x.UpdateCardText == "Update Available"
);
}
ModelCards =new ObservableCollection<CheckpointBrowserCardViewModel>(filteredCards);
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(filteredCards);
}
TotalPages = metadata?.TotalPages ?? 1;
CanGoToFirstPage = CurrentPageNumber != 1;
@ -304,14 +394,14 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private async Task SearchModels()
{
var timer = Stopwatch.StartNew();
if (SearchQuery != previousSearchQuery)
{
// Reset page number
CurrentPageNumber = 1;
previousSearchQuery = SearchQuery;
}
// Build request
var modelRequest = new CivitModelsRequest
{
@ -334,10 +424,10 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
{
modelRequest.Query = SearchQuery;
}
if (SelectedModelType != CivitModelType.All)
{
modelRequest.Types = new[] {SelectedModelType};
modelRequest.Types = new[] { SelectedModelType };
}
if (SelectedBaseModelType != "All")
@ -347,9 +437,9 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
if (SortMode == CivitSortMode.Installed)
{
var connectedModels =
CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory)
.Where(c => c.IsConnectedModel);
var connectedModels = CheckpointFile
.GetAllCheckpointFiles(settingsManager.ModelsDirectory)
.Where(c => c.IsConnectedModel);
if (SelectedModelType != CivitModelType.All)
{
@ -358,24 +448,34 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
modelRequest = new CivitModelsRequest
{
CommaSeparatedModelIds = string.Join(",",
connectedModels.Select(c => c.ConnectedModel!.ModelId).GroupBy(m => m)
.Select(g => g.First())),
Types = SelectedModelType == CivitModelType.All ? null : new[] {SelectedModelType}
CommaSeparatedModelIds = string.Join(
",",
connectedModels
.Select(c => c.ConnectedModel!.ModelId)
.GroupBy(m => m)
.Select(g => g.First())
),
Types =
SelectedModelType == CivitModelType.All ? null : new[] { SelectedModelType },
Page = CurrentPageNumber,
};
}
// See if query is cached
var cachedQuery = await liteDbContext.CivitModelQueryCache
.IncludeAll()
.FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest));
// If cached, update model cards
if (cachedQuery is not null)
{
var elapsed = timer.Elapsed;
Logger.Debug("Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)",
SearchQuery, modelRequest.GetHashCode(), elapsed.TotalSeconds);
Logger.Debug(
"Using cached query for {Text} [{RequestHash}] (in {Elapsed:F1} s)",
SearchQuery,
modelRequest.GetHashCode(),
elapsed.TotalSeconds
);
UpdateModelCards(cachedQuery.Items, cachedQuery.Metadata);
// Start remote query (background mode)
@ -386,7 +486,8 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
CivitModelQuery(modelRequest).SafeFireAndForget();
Logger.Debug(
"Cached query was more than 2 minutes ago ({Seconds:F0} s), updating cache with remote query",
timeSinceCache.Value.TotalSeconds);
timeSinceCache.Value.TotalSeconds
);
}
}
else
@ -395,7 +496,7 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
ShowMainLoadingSpinner = true;
await CivitModelQuery(modelRequest);
}
UpdateResultsText();
}
@ -406,17 +507,17 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
public void PreviousPage()
{
if (CurrentPageNumber == 1)
if (CurrentPageNumber == 1)
return;
CurrentPageNumber--;
}
public void NextPage()
{
if (CurrentPageNumber == TotalPages)
if (CurrentPageNumber == TotalPages)
return;
CurrentPageNumber++;
}
@ -429,46 +530,75 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
{
settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value);
// ModelCardsView?.Refresh();
var updateCards = allModelCards
.Where(FilterModelCardsPredicate);
var updateCards = allModelCards.Where(FilterModelCardsPredicate);
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(updateCards);
if (!HasSearched) return;
if (!HasSearched)
return;
UpdateResultsText();
}
partial void OnSelectedPeriodChanged(CivitPeriod value)
{
TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions(
value, SortMode, SelectedModelType, SelectedBaseModelType));
settingsManager.Transaction(
s =>
s.ModelSearchOptions = new ModelSearchOptions(
value,
SortMode,
SelectedModelType,
SelectedBaseModelType
)
);
}
partial void OnSortModeChanged(CivitSortMode value)
{
TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod, value, SelectedModelType, SelectedBaseModelType));
settingsManager.Transaction(
s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
value,
SelectedModelType,
SelectedBaseModelType
)
);
}
partial void OnSelectedModelTypeChanged(CivitModelType value)
{
TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod, SortMode, value, SelectedBaseModelType));
settingsManager.Transaction(
s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
value,
SelectedBaseModelType
)
);
}
partial void OnSelectedBaseModelTypeChanged(string value)
{
TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod, SortMode, SelectedModelType, value));
settingsManager.Transaction(
s =>
s.ModelSearchOptions = new ModelSearchOptions(
SelectedPeriod,
SortMode,
SelectedModelType,
value
)
);
}
private async Task TrySearchAgain(bool shouldUpdatePageNumber = true)
{
if (!HasSearched) return;
if (!HasSearched)
return;
ModelCards?.Clear();
if (shouldUpdatePageNumber)
@ -483,11 +613,13 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
private void UpdateResultsText()
{
NoResultsFound = ModelCards?.Count <= 0;
NoResultsText = allModelCards.Count > 0
? $"{allModelCards.Count} results hidden by filters"
: "No results found";
NoResultsText =
allModelCards.Count > 0
? $"{allModelCards.Count} results hidden by filters"
: "No results found";
}
public override string Title => "Model Browser";
public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true };
public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.BrainCircuit, IsFilled = true };
}

Loading…
Cancel
Save