Browse Source

added base model filter to checkpoints page (wip)

and some other bug fixes n improvements n stuff idk read the chagenlog
pull/438/head
JT 11 months ago
parent
commit
f0fba442ee
  1. 7
      CHANGELOG.md
  2. 5
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  3. 16
      StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs
  4. 12
      StabilityMatrix.Avalonia/ViewModels/CheckpointManager/BaseModelOptionViewModel.cs
  5. 50
      StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
  6. 62
      StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
  7. 8
      StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs
  8. 2
      StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs
  9. 7
      StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs
  10. 96
      StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs
  11. 34
      StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
  12. 32
      StabilityMatrix.Core/Models/Api/CivitBaseModelType.cs
  13. 2
      StabilityMatrix.Core/Services/MetadataImportService.cs

7
CHANGELOG.md

@ -5,6 +5,13 @@ All notable changes to Stability Matrix will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2.0.0.html).
## v2.8.0-pre.1
### Added
- Added base model filter to Checkpoints page
### Fixed
- Inference file name patterns with directory separator characters will now have the subdirectories created automatically
- Fixed missing up/downgrade buttons on the Python Packages dialog when the version was not semver compatible
## v2.8.0-dev.4
### Added
- Auto-update support for macOS

5
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -181,6 +181,11 @@ public abstract partial class InferenceGenerationViewModelBase
file = outputDir.JoinFile($"{fileName}_{uuid}.{fileExtension}");
}
if (file.Info.DirectoryName != null)
{
Directory.CreateDirectory(file.Info.DirectoryName);
}
await using var fileStream = file.Info.OpenWrite();
await imageStream.CopyToAsync(fileStream);

16
StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs

@ -129,18 +129,8 @@ public partial class CivitAiBrowserViewModel : TabViewModelBase
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString());
public List<string> BaseModelOptions =>
[
"All",
"SD 1.5",
"SD 1.5 LCM",
"SD 2.1",
"SDXL 0.9",
"SDXL 1.0",
"SDXL 1.0 LCM",
"SDXL Turbo",
"Other"
];
public IEnumerable<string> BaseModelOptions =>
Enum.GetValues<CivitBaseModelType>().Select(t => t.GetStringValue());
public CivitAiBrowserViewModel(
ICivitApi civitApi,
@ -424,7 +414,7 @@ public partial class CivitAiBrowserViewModel : TabViewModelBase
if (SelectedModelType != CivitModelType.All)
{
modelRequest.Types = new[] { SelectedModelType };
modelRequest.Types = [SelectedModelType];
}
if (SelectedBaseModelType != "All")

12
StabilityMatrix.Avalonia/ViewModels/CheckpointManager/BaseModelOptionViewModel.cs

@ -0,0 +1,12 @@
using CommunityToolkit.Mvvm.ComponentModel;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
public partial class BaseModelOptionViewModel : ObservableObject
{
[ObservableProperty]
private bool isSelected;
[ObservableProperty]
private string modelType = string.Empty;
}

50
StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs

@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Collections.Specialized;
using System.IO;
using System.Linq;
@ -22,6 +23,7 @@ using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
@ -85,6 +87,14 @@ public partial class CheckpointFolder : ViewModelBase
[ObservableProperty]
private string searchFilter = string.Empty;
[ObservableProperty]
private ObservableCollection<string> baseModelOptions =
new(
Enum.GetValues<CivitBaseModelType>()
.Where(x => x != CivitBaseModelType.All)
.Select(x => x.GetStringValue())
);
public bool IsDragBlurEnabled => IsCurrentDragTarget || IsImportInProgress;
public string TitleWithFilesCount =>
@ -141,6 +151,7 @@ public partial class CheckpointFolder : ViewModelBase
f.FileName.Contains(SearchFilter, StringComparison.OrdinalIgnoreCase)
|| f.Title.Contains(SearchFilter, StringComparison.OrdinalIgnoreCase)
)
.Filter(BaseModelFilter)
.Bind(DisplayedCheckpointFiles)
.Subscribe();
@ -161,6 +172,13 @@ public partial class CheckpointFolder : ViewModelBase
// DisplayedCheckpointFiles = CheckpointFiles;
}
private bool BaseModelFilter(CheckpointFile file)
{
return file.IsConnectedModel
? BaseModelOptions.Contains(file.ConnectedModel!.BaseModel)
: BaseModelOptions.Contains("Other");
}
/// <summary>
/// When title is set, set the category enabled state from settings.
/// </summary>
@ -187,6 +205,16 @@ public partial class CheckpointFolder : ViewModelBase
checkpointFilesCache.Refresh();
}
partial void OnBaseModelOptionsChanged(ObservableCollection<string> value)
{
foreach (var subFolder in SubFolders)
{
subFolder.BaseModelOptions = new ObservableCollection<string>(value);
}
checkpointFilesCache.Refresh();
}
/// <summary>
/// When toggling the category enabled state, save it to settings.
/// </summary>
@ -393,7 +421,9 @@ public partial class CheckpointFolder : ViewModelBase
Progress.Value = report.Percentage;
// For multiple files, add count
Progress.Text =
copyPaths.Count > 1 ? $"Importing {report.Title} ({report.Message})" : $"Importing {report.Title}";
copyPaths.Count > 1
? $"Importing {report.Title} ({report.Message})"
: $"Importing {report.Title}";
});
await FileTransfers.CopyFiles(copyPaths, progress);
@ -603,15 +633,13 @@ public partial class CheckpointFolder : ViewModelBase
SubFoldersCache.EditDiff(updatedFolders, (a, b) => a.Title == b.Title);
// Index files
Dispatcher
.UIThread
.Post(
() =>
{
var files = GetCheckpointFiles();
checkpointFilesCache.EditDiff(files, CheckpointFile.FilePathComparer);
},
DispatcherPriority.Background
);
Dispatcher.UIThread.Post(
() =>
{
var files = GetCheckpointFiles();
checkpointFilesCache.EditDiff(files, CheckpointFile.FilePathComparer);
},
DispatcherPriority.Background
);
}
}

62
StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs

@ -1,8 +1,10 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
@ -17,7 +19,9 @@ using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;
@ -41,7 +45,19 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
public override string Title => "Checkpoints";
public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.Notebook, IsFilled = true };
public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.Notebook, IsFilled = true };
[ObservableProperty]
private ObservableCollection<string> baseModelOptions =
new(
Enum.GetValues<CivitBaseModelType>()
.Where(x => x != CivitBaseModelType.All)
.Select(x => x.GetStringValue())
);
[ObservableProperty]
private ObservableCollection<string> selectedBaseModels = [];
// Toggle button for auto hashing new drag-and-dropped files for connected upgrade
[ObservableProperty]
@ -97,12 +113,24 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
this.metadataImportService = metadataImportService;
this.modelFinder = modelFinder;
SelectedBaseModels = new ObservableCollection<string>(BaseModelOptions);
SelectedBaseModels.CollectionChanged += (sender, args) =>
{
foreach (var folder in CheckpointFolders)
{
folder.BaseModelOptions = new ObservableCollection<string>(SelectedBaseModels);
}
CheckpointFoldersCache.Refresh();
};
CheckpointFoldersCache
.Connect()
.DeferUntilLoaded()
.SortBy(x => x.Title)
.Bind(CheckpointFolders)
.Filter(ContainsSearchFilter)
.Filter(ContainsBaseModel)
.Bind(DisplayedCheckpointFolders)
.Subscribe();
}
@ -165,8 +193,7 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
private bool ContainsSearchFilter(CheckpointFolder folder)
{
if (folder == null)
throw new ArgumentNullException(nameof(folder));
ArgumentNullException.ThrowIfNull(folder);
if (string.IsNullOrWhiteSpace(SearchFilter))
{
@ -174,12 +201,28 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
}
// Check files in the current folder
return folder.CheckpointFiles.Any(x => x.FileName.Contains(SearchFilter, StringComparison.OrdinalIgnoreCase))
return folder.CheckpointFiles.Any(
x =>
x.FileName.Contains(SearchFilter, StringComparison.OrdinalIgnoreCase)
|| x.ConnectedModel?.ModelName.Contains(SearchFilter, StringComparison.OrdinalIgnoreCase)
== true
|| x.ConnectedModel?.Tags.Any(
t => t.Contains(SearchFilter, StringComparison.OrdinalIgnoreCase)
) == true
)
||
// If no matching files were found in the current folder, check in all subfolders
folder.SubFolders.Any(ContainsSearchFilter);
}
private bool ContainsBaseModel(CheckpointFolder folder)
{
ArgumentNullException.ThrowIfNull(folder);
return folder.CheckpointFiles.Any(x => SelectedBaseModels.Contains(x.ConnectedModel?.BaseModel))
|| folder.SubFolders.Any(ContainsBaseModel);
}
private void IndexFolders()
{
var modelsDirectory = settingsManager.ModelsDirectory;
@ -241,9 +284,16 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
Progress = report;
});
await metadataImportService.ScanDirectoryForMissingInfo(settingsManager.ModelsDirectory, progressHandler);
await metadataImportService.ScanDirectoryForMissingInfo(
settingsManager.ModelsDirectory,
progressHandler
);
notificationService.Show("Scan Complete", "Finished scanning for missing metadata.", NotificationType.Success);
notificationService.Show(
"Scan Complete",
"Finished scanning for missing metadata.",
NotificationType.Success
);
DelayedClearProgress(TimeSpan.FromSeconds(1.5));
}

8
StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs

@ -1,4 +1,5 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Controls;
@ -49,8 +50,9 @@ public partial class PythonPackagesItemViewModel : ViewModelBase
|| !SemVersion.TryParse(value, out var selectedSemver)
)
{
CanUpgrade = false;
CanDowngrade = false;
var compare = string.CompareOrdinal(value, Package.Version);
CanUpgrade = compare > 0;
CanDowngrade = compare < 0;
return;
}

2
StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs

@ -1,9 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using Newtonsoft.Json;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;

7
StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs

@ -1,14 +1,11 @@
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using Newtonsoft.Json;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
namespace StabilityMatrix.Avalonia.ViewModels.Inference;

96
StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs

@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
@ -34,35 +35,22 @@ namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(NewCheckpointsPage))]
[Singleton]
public partial class NewCheckpointsPageViewModel : PageViewModelBase
public partial class NewCheckpointsPageViewModel(
ILogger<NewCheckpointsPageViewModel> logger,
ISettingsManager settingsManager,
ILiteDbContext liteDbContext,
ICivitApi civitApi,
ServiceManager<ViewModelBase> dialogFactory,
INotificationService notificationService,
IDownloadService downloadService,
ModelFinder modelFinder,
IMetadataImportService metadataImportService
) : PageViewModelBase
{
private readonly ILogger<NewCheckpointsPageViewModel> logger;
private readonly ISettingsManager settingsManager;
private readonly ILiteDbContext liteDbContext;
private readonly ICivitApi civitApi;
private readonly ServiceManager<ViewModelBase> dialogFactory;
private readonly INotificationService notificationService;
public override string Title => "Checkpoint Manager";
public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.Cellular5g, IsFilled = true };
public NewCheckpointsPageViewModel(
ILogger<NewCheckpointsPageViewModel> logger,
ISettingsManager settingsManager,
ILiteDbContext liteDbContext,
ICivitApi civitApi,
ServiceManager<ViewModelBase> dialogFactory,
INotificationService notificationService
)
{
this.logger = logger;
this.settingsManager = settingsManager;
this.liteDbContext = liteDbContext;
this.civitApi = civitApi;
this.dialogFactory = dialogFactory;
this.notificationService = notificationService;
}
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(ConnectedCheckpoints))]
[NotifyPropertyChangedFor(nameof(NonConnectedCheckpoints))]
@ -89,7 +77,61 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
if (Design.IsDesignMode)
return;
var files = CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory);
var files = CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory).ToList();
var uniqueSubFolders = files
.Select(
x =>
x.FilePath.Replace(settingsManager.ModelsDirectory, string.Empty)
.Replace(x.FileName, string.Empty)
.Trim(Path.DirectorySeparatorChar)
)
.Distinct()
.Where(x => x.Contains(Path.DirectorySeparatorChar))
.Where(x => Directory.Exists(Path.Combine(settingsManager.ModelsDirectory, x)))
.ToList();
var checkpointFolders = Enum.GetValues<SharedFolderType>()
.Where(x => Directory.Exists(Path.Combine(settingsManager.ModelsDirectory, x.ToString())))
.Select(
folderType =>
new CheckpointFolder(
settingsManager,
downloadService,
modelFinder,
notificationService,
metadataImportService
)
{
Title = folderType.ToString(),
DirectoryPath = Path.Combine(settingsManager.ModelsDirectory, folderType.ToString()),
FolderType = folderType,
IsExpanded = true,
}
)
.ToList();
foreach (var folder in uniqueSubFolders)
{
var folderType = Enum.Parse<SharedFolderType>(folder.Split(Path.DirectorySeparatorChar)[0]);
var parentFolder = checkpointFolders.FirstOrDefault(x => x.FolderType == folderType);
var checkpointFolder = new CheckpointFolder(
settingsManager,
downloadService,
modelFinder,
notificationService,
metadataImportService
)
{
Title = folderType.ToString(),
DirectoryPath = Path.Combine(settingsManager.ModelsDirectory, folder),
FolderType = folderType,
ParentFolder = parentFolder,
IsExpanded = true,
};
parentFolder?.SubFolders.Add(checkpointFolder);
}
AllCheckpoints = new ObservableCollection<CheckpointFile>(files);
var connectedModelIds = ConnectedCheckpoints.Select(x => x.ConnectedModel.ModelId);
@ -99,8 +141,8 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
};
// See if query is cached
var cachedQuery = await liteDbContext.CivitModelQueryCache
.IncludeAll()
var cachedQuery = await liteDbContext
.CivitModelQueryCache.IncludeAll()
.FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest));
// If cached, update model cards

34
StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml

@ -10,6 +10,9 @@
xmlns:checkpointManager="clr-namespace:StabilityMatrix.Avalonia.ViewModels.CheckpointManager"
xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
xmlns:avalonia="https://github.com/projektanker/icons.avalonia"
xmlns:api="clr-namespace:StabilityMatrix.Core.Models.Api;assembly=StabilityMatrix.Core"
xmlns:generic="clr-namespace:System.Collections.Generic;assembly=System.Collections"
xmlns:converters="clr-namespace:StabilityMatrix.Avalonia.Converters"
d:DataContext="{x:Static mocks:DesignData.CheckpointsPageViewModel}"
x:CompileBindings="True"
x:DataType="vm:CheckpointsPageViewModel"
@ -24,6 +27,8 @@
Color="#FF000000"
Opacity="0.2"
x:Key="TextDropShadowEffect" />
<converters:EnumStringConverter x:Key="EnumStringConverter" />
<!-- Template for a single badge -->
<DataTemplate DataType="{x:Type system:String}" x:Key="BadgeTemplate">
@ -490,6 +495,35 @@
</ui:FAMenuFlyout>
</DropDownButton.Flyout>
</DropDownButton>
<DropDownButton
x:Name="BaseModelDropdown"
Content="{x:Static lang:Resources.Label_BaseModel}"
Margin="8,0"
VerticalAlignment="Center"
HorizontalAlignment="Right">
<DropDownButton.Flyout>
<Flyout>
<ListBox ItemsSource="{Binding BaseModelOptions}"
SelectionMode="Multiple, Toggle"
SelectedItems="{Binding SelectedBaseModels}">
<ListBox.Template>
<ControlTemplate>
<ItemsPresenter />
</ControlTemplate>
</ListBox.Template>
<ListBox.ItemTemplate>
<DataTemplate DataType="{x:Type checkpointManager:BaseModelOptionViewModel}">
<TextBlock Text="{Binding ModelType}"/>
</DataTemplate>
</ListBox.ItemTemplate>
</ListBox>
</Flyout>
</DropDownButton.Flyout>
</DropDownButton>
</StackPanel>
<ui:CommandBar
Grid.Row="0"

32
StabilityMatrix.Core/Models/Api/CivitBaseModelType.cs

@ -0,0 +1,32 @@
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Extensions;
namespace StabilityMatrix.Core.Models.Api;
[JsonConverter(typeof(JsonStringEnumConverter<CivitBaseModelType>))]
public enum CivitBaseModelType
{
All,
[StringValue("SD 1.5")]
Sd15,
[StringValue("SD 1.5 LCM")]
Sd15Lcm,
[StringValue("SD 2.1")]
Sd21,
[StringValue("SDXL 0.9")]
Sdxl09,
[StringValue("SDXL 1.0")]
Sdxl10,
[StringValue("SDXL 1.0 LCM")]
Sdxl10Lcm,
[StringValue("SDXL Turbo")]
SdxlTurbo,
Other,
}

2
StabilityMatrix.Core/Services/MetadataImportService.cs

@ -237,7 +237,7 @@ public class MetadataImportService(
new ProgressReport(
current: report.Current ?? 0,
total: report.Total ?? 0,
$"Getting metadata for {filePath} ... {report.Percentage}%"
$"Getting metadata for {fileNameWithoutExtension} ... {report.Percentage}%"
)
);
});

Loading…
Cancel
Save