Browse Source

fix build

pull/55/head
JT 1 year ago
parent
commit
cfc1f147d3
  1. 17
      StabilityMatrix.Avalonia/ViewModels/CheckpointFolder.cs
  2. 24
      StabilityMatrix.Avalonia/ViewModels/ProgressViewModel.cs
  3. 2
      StabilityMatrix.Core/Helper/SharedFolders.cs
  4. 3
      StabilityMatrix/CheckpointManagerPage.xaml
  5. 4
      StabilityMatrix/DesignData/MockCheckpointFolder.cs
  6. 226
      StabilityMatrix/Models/CheckpointFile.cs
  7. 332
      StabilityMatrix/Models/CheckpointFolder.cs
  8. 4
      StabilityMatrix/StabilityMatrix.csproj
  9. 1
      StabilityMatrix/ViewModels/CheckpointManagerViewModel.cs
  10. 1
      StabilityMatrix/ViewModels/TextToImageViewModel.cs

17
StabilityMatrix.Avalonia/ViewModels/CheckpointFolder.cs

@ -6,6 +6,8 @@ using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Input;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Core.Extensions;
@ -19,7 +21,6 @@ namespace StabilityMatrix.Avalonia.ViewModels;
public partial class CheckpointFolder : ObservableObject
{
private readonly IDialogFactory dialogFactory;
private readonly ISettingsManager settingsManager;
private readonly IDownloadService downloadService;
private readonly ModelFinder modelFinder;
@ -35,7 +36,7 @@ public partial class CheckpointFolder : ObservableObject
/// Custom title for UI.
/// </summary>
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(Models.CheckpointFolder.FolderType))]
[NotifyPropertyChangedFor(nameof(FolderType))]
[NotifyPropertyChangedFor(nameof(TitleWithFilesCount))]
private string title = string.Empty;
@ -68,13 +69,11 @@ public partial class CheckpointFolder : ObservableObject
public RelayCommand OnPreviewDragLeaveCommand => new(() => IsCurrentDragTarget = false);
public CheckpointFolder(
IDialogFactory dialogFactory,
ISettingsManager settingsManager,
IDownloadService downloadService,
ModelFinder modelFinder,
bool useCategoryVisibility = true)
{
this.dialogFactory = dialogFactory;
this.settingsManager = settingsManager;
this.downloadService = downloadService;
this.modelFinder = modelFinder;
@ -129,7 +128,7 @@ public partial class CheckpointFolder : ObservableObject
/// <param name="file"></param>
private void OnCheckpointFileDelete(object? sender, CheckpointFile file)
{
Application.Current.Dispatcher.Invoke(() => CheckpointFiles.Remove(file));
Dispatcher.UIThread.Invoke(() => CheckpointFiles.Remove(file));
}
[RelayCommand]
@ -138,7 +137,7 @@ public partial class CheckpointFolder : ObservableObject
IsImportInProgress = true;
IsCurrentDragTarget = false;
if (e.Data.GetData(DataFormats.FileDrop) is not string[] files || files.Length < 1)
if (e.Data.Get(DataFormats.Files) is not string[] files || files.Length < 1)
{
IsImportInProgress = false;
return;
@ -290,10 +289,10 @@ public partial class CheckpointFolder : ObservableObject
return await (progress switch
{
null => Task.Run(() =>
CheckpointFile.FromDirectoryIndex(dialogFactory, DirectoryPath).ToList()),
CheckpointFile.FromDirectoryIndex(DirectoryPath).ToList()),
_ => Task.Run(() =>
CheckpointFile.FromDirectoryIndex(dialogFactory, DirectoryPath, progress).ToList())
CheckpointFile.FromDirectoryIndex(DirectoryPath, progress).ToList())
});
}
@ -306,7 +305,7 @@ public partial class CheckpointFolder : ObservableObject
foreach (var folder in Directory.GetDirectories(DirectoryPath))
{
// Inherit our folder type
var subFolder = new CheckpointFolder(dialogFactory, settingsManager,
var subFolder = new CheckpointFolder(settingsManager,
downloadService, modelFinder,
useCategoryVisibility: false)
{

24
StabilityMatrix.Avalonia/ViewModels/ProgressViewModel.cs

@ -0,0 +1,24 @@
using CommunityToolkit.Mvvm.ComponentModel;
namespace StabilityMatrix.Avalonia.ViewModels;
/// <summary>
/// Generic view model for progress reporting.
/// </summary>
public partial class ProgressViewModel : ObservableObject
{
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsTextVisible))]
private string text;
[ObservableProperty]
private double value;
[ObservableProperty]
private bool isIndeterminate;
[ObservableProperty]
private bool isProgressVisible;
public virtual bool IsTextVisible => string.IsNullOrWhiteSpace(Text);
}

2
StabilityMatrix.Core/Helper/SharedFolders.cs

@ -21,7 +21,7 @@ public class SharedFolders : ISharedFolders
this.packageFactory = packageFactory;
}
internal static void SetupLinks(Dictionary<SharedFolderType, string> definitions,
public static void SetupLinks(Dictionary<SharedFolderType, string> definitions,
DirectoryPath modelsDirectory, DirectoryPath installDirectory)
{
foreach (var (folderType, relativePath) in definitions)

3
StabilityMatrix/CheckpointManagerPage.xaml

@ -19,7 +19,8 @@
xmlns:system="clr-namespace:System;assembly=System.Runtime"
xmlns:ui="http://schemas.lepo.co/wpfui/2022/xaml"
xmlns:viewModels="clr-namespace:StabilityMatrix.ViewModels"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml">
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:models="clr-namespace:StabilityMatrix.Models">
<Page.Resources>
<BooleanToVisibilityConverter x:Key="BoolToVisibilityConverter" />

4
StabilityMatrix/DesignData/MockCheckpointFolder.cs

@ -1,4 +1,6 @@
namespace StabilityMatrix.DesignData;
using StabilityMatrix.Models;
namespace StabilityMatrix.DesignData;
public class MockCheckpointFolder : CheckpointFolder
{

226
StabilityMatrix/Models/CheckpointFile.cs

@ -0,0 +1,226 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Windows;
using System.Windows.Media.Imaging;
using AsyncAwaitBestPractices;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Helper;
namespace StabilityMatrix.Models;
public partial class CheckpointFile : ObservableObject
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly IDialogFactory dialogFactory;
// Event for when this file is deleted
public event EventHandler<CheckpointFile>? Deleted;
/// <summary>
/// Absolute path to the checkpoint file.
/// </summary>
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(FileName))]
private string filePath = string.Empty;
/// <summary>
/// Custom title for UI.
/// </summary>
[ObservableProperty]
private string title = string.Empty;
public string? PreviewImagePath { get; set; }
public BitmapImage? PreviewImage { get; set; }
public bool IsPreviewImageLoaded => PreviewImage != null;
[ObservableProperty]
private ConnectedModelInfo? connectedModel;
public bool IsConnectedModel => ConnectedModel != null;
[ObservableProperty] private bool isLoading;
public string FileName => Path.GetFileName(FilePath);
public ObservableCollection<string> Badges { get; set; } = new();
private static readonly string[] SupportedCheckpointExtensions = { ".safetensors", ".pt", ".ckpt", ".pth", "bin" };
private static readonly string[] SupportedImageExtensions = { ".png", ".jpg", ".jpeg" };
private static readonly string[] SupportedMetadataExtensions = { ".json" };
public CheckpointFile(IDialogFactory dialogFactory)
{
this.dialogFactory = dialogFactory;
}
partial void OnConnectedModelChanged(ConnectedModelInfo? value)
{
// Update title, first check user defined, then connected model name
Title = value?.UserTitle ?? value?.ModelName ?? string.Empty;
// Update badges
Badges.Clear();
var fpType = value.FileMetadata.Fp?.GetStringValue().ToUpperInvariant();
if (fpType != null)
{
Badges.Add(fpType);
}
if (!string.IsNullOrWhiteSpace(value.BaseModel))
{
Badges.Add(value.BaseModel);
}
}
[RelayCommand]
private async Task DeleteAsync()
{
if (File.Exists(FilePath))
{
IsLoading = true;
try
{
await using var delay = new MinimumDelay(200, 500);
await Task.Run(() => File.Delete(FilePath));
if (PreviewImagePath != null && File.Exists(PreviewImagePath))
{
await Task.Run(() => File.Delete(PreviewImagePath));
}
}
catch (IOException ex)
{
Logger.Warn($"Failed to delete checkpoint file {FilePath}: {ex.Message}");
return; // Don't delete from collection
}
finally
{
IsLoading = false;
}
}
Deleted?.Invoke(this, this);
}
[RelayCommand]
private async Task RenameAsync()
{
var responses = await dialogFactory.ShowTextEntryDialog("Rename Model", new []
{
("File Name", FileName)
});
var name = responses?.FirstOrDefault();
if (name == null) return;
// Rename file in OS
try
{
var newFilePath = Path.Combine(Path.GetDirectoryName(FilePath) ?? "", name);
File.Move(FilePath, newFilePath);
FilePath = newFilePath;
}
catch (Exception e)
{
Console.WriteLine(e);
throw;
}
}
[RelayCommand]
private void OpenOnCivitAi()
{
ProcessRunner.OpenUrl($"https://civitai.com/models/{ConnectedModel.ModelId}");
}
// Loads image from path
private async Task LoadPreviewImage()
{
if (PreviewImagePath == null) return;
var bytes = await File.ReadAllBytesAsync(PreviewImagePath);
await Application.Current.Dispatcher.InvokeAsync(() =>
{
var bitmap = new BitmapImage();
using var ms = new MemoryStream(bytes);
bitmap.BeginInit();
bitmap.StreamSource = ms;
bitmap.CacheOption = BitmapCacheOption.OnLoad;
bitmap.EndInit();
PreviewImage = bitmap;
});
}
/// <summary>
/// Indexes directory and yields all checkpoint files.
/// First we match all files with supported extensions.
/// If found, we also look for
/// - {filename}.preview.{image-extensions} (preview image)
/// - {filename}.cm-info.json (connected model info)
/// </summary>
public static IEnumerable<CheckpointFile> FromDirectoryIndex(IDialogFactory dialogFactory, string directory, SearchOption searchOption = SearchOption.TopDirectoryOnly)
{
// Get all files with supported extensions
var allExtensions = SupportedCheckpointExtensions
.Concat(SupportedImageExtensions)
.Concat(SupportedMetadataExtensions);
var files = allExtensions.AsParallel()
.SelectMany(pattern => Directory.EnumerateFiles(directory, $"*{pattern}", searchOption)).ToDictionary<string, string>(Path.GetFileName);
foreach (var file in files.Keys.Where(k => SupportedCheckpointExtensions.Contains(Path.GetExtension(k))))
{
var checkpointFile = new CheckpointFile(dialogFactory)
{
Title = Path.GetFileNameWithoutExtension(file),
FilePath = Path.Combine(directory, file),
};
// Check for connected model info
var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(file);
var cmInfoPath = $"{fileNameWithoutExtension}.cm-info.json";
if (files.TryGetValue(cmInfoPath, out var jsonPath))
{
try
{
var jsonData = File.ReadAllText(jsonPath);
checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(jsonData);
}
catch (IOException e)
{
Debug.WriteLine($"Failed to parse {cmInfoPath}: {e}");
}
}
// Check for preview image
var previewImage = SupportedImageExtensions.Select(ext => $"{fileNameWithoutExtension}.preview{ext}").FirstOrDefault(files.ContainsKey);
if (previewImage != null)
{
checkpointFile.PreviewImagePath = files[previewImage];
checkpointFile.LoadPreviewImage().SafeFireAndForget();
}
yield return checkpointFile;
}
}
/// <summary>
/// Index with progress reporting.
/// </summary>
public static IEnumerable<CheckpointFile> FromDirectoryIndex(IDialogFactory dialogFactory, string directory, IProgress<ProgressReport> progress,
SearchOption searchOption = SearchOption.TopDirectoryOnly)
{
var current = 0ul;
foreach (var checkpointFile in FromDirectoryIndex(dialogFactory, directory, searchOption))
{
current++;
progress.Report(new ProgressReport(current, "Indexing", checkpointFile.FileName));
yield return checkpointFile;
}
}
}

332
StabilityMatrix/Models/CheckpointFolder.cs

@ -0,0 +1,332 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Collections.Specialized;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Windows;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Services;
using StabilityMatrix.Helper;
using StabilityMatrix.ViewModels;
namespace StabilityMatrix.Models;
public partial class CheckpointFolder : ObservableObject
{
private readonly IDialogFactory dialogFactory;
private readonly ISettingsManager settingsManager;
private readonly IDownloadService downloadService;
private readonly ModelFinder modelFinder;
// ReSharper disable once FieldCanBeMadeReadOnly.Local
private bool useCategoryVisibility;
/// <summary>
/// Absolute path to the folder.
/// </summary>
public string DirectoryPath { get; init; } = string.Empty;
/// <summary>
/// Custom title for UI.
/// </summary>
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(FolderType))]
[NotifyPropertyChangedFor(nameof(TitleWithFilesCount))]
private string title = string.Empty;
[ObservableProperty]
private SharedFolderType folderType;
/// <summary>
/// True if the category is enabled for the manager page.
/// </summary>
[ObservableProperty]
private bool isCategoryEnabled = true;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsDragBlurEnabled))]
private bool isCurrentDragTarget;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsDragBlurEnabled))]
private bool isImportInProgress;
public bool IsDragBlurEnabled => IsCurrentDragTarget || IsImportInProgress;
public string TitleWithFilesCount => CheckpointFiles.Any() ? $"{Title} ({CheckpointFiles.Count})" : Title;
public ProgressViewModel Progress { get; } = new();
public ObservableCollection<CheckpointFolder> SubFolders { get; init; } = new();
public ObservableCollection<CheckpointFile> CheckpointFiles { get; init; } = new();
public RelayCommand OnPreviewDragEnterCommand => new(() => IsCurrentDragTarget = true);
public RelayCommand OnPreviewDragLeaveCommand => new(() => IsCurrentDragTarget = false);
public CheckpointFolder(
IDialogFactory dialogFactory,
ISettingsManager settingsManager,
IDownloadService downloadService,
ModelFinder modelFinder,
bool useCategoryVisibility = true)
{
this.dialogFactory = dialogFactory;
this.settingsManager = settingsManager;
this.downloadService = downloadService;
this.modelFinder = modelFinder;
this.useCategoryVisibility = useCategoryVisibility;
CheckpointFiles.CollectionChanged += OnCheckpointFilesChanged;
}
/// <summary>
/// When title is set, set the category enabled state from settings.
/// </summary>
// ReSharper disable once UnusedParameterInPartialMethod
partial void OnTitleChanged(string value)
{
if (!useCategoryVisibility) return;
// Update folder type
var result = Enum.TryParse(Title, out SharedFolderType type);
FolderType = result ? type : new SharedFolderType();
IsCategoryEnabled = settingsManager.IsSharedFolderCategoryVisible(FolderType);
}
/// <summary>
/// When toggling the category enabled state, save it to settings.
/// </summary>
partial void OnIsCategoryEnabledChanged(bool value)
{
if (!useCategoryVisibility) return;
if (value != settingsManager.IsSharedFolderCategoryVisible(FolderType))
{
settingsManager.SetSharedFolderCategoryVisible(FolderType, value);
}
}
// On collection changes
private void OnCheckpointFilesChanged(object? sender, NotifyCollectionChangedEventArgs e)
{
OnPropertyChanged(nameof(TitleWithFilesCount));
if (e.NewItems == null) return;
// On new added items, add event handler for deletion
foreach (CheckpointFile item in e.NewItems)
{
item.Deleted += OnCheckpointFileDelete;
}
}
/// <summary>
/// Handler for CheckpointFile requesting to be deleted from the collection.
/// </summary>
/// <param name="sender"></param>
/// <param name="file"></param>
private void OnCheckpointFileDelete(object? sender, CheckpointFile file)
{
Application.Current.Dispatcher.Invoke(() => CheckpointFiles.Remove(file));
}
[RelayCommand]
private async Task OnPreviewDropAsync(DragEventArgs e)
{
IsImportInProgress = true;
IsCurrentDragTarget = false;
if (e.Data.GetData(DataFormats.FileDrop) is not string[] files || files.Length < 1)
{
IsImportInProgress = false;
return;
}
await ImportFilesAsync(files, settingsManager.Settings.IsImportAsConnected);
}
[RelayCommand]
private void ShowInExplorer(string path)
{
Process.Start("explorer.exe", path);
}
/// <summary>
/// Imports files to the folder. Reports progress to instance properties.
/// </summary>
public async Task ImportFilesAsync(IEnumerable<string> files, bool convertToConnected = false)
{
Progress.IsIndeterminate = true;
Progress.IsProgressVisible = true;
var copyPaths = files.ToDictionary(k => k, v => Path.Combine(DirectoryPath, Path.GetFileName(v)));
var progress = new Progress<ProgressReport>(report =>
{
Progress.IsIndeterminate = false;
Progress.Value = report.Percentage;
// For multiple files, add count
Progress.Text = copyPaths.Count > 1 ? $"Importing {report.Title} ({report.Message})" : $"Importing {report.Title}";
});
await FileTransfers.CopyFiles(copyPaths, progress);
// Hash files and convert them to connected model if found
if (convertToConnected)
{
var modelFilesCount = copyPaths.Count;
var modelFiles = copyPaths.Values
.Select(path => new FilePath(path));
// Holds tasks for model queries after hash
var modelQueryTasks = new List<Task<bool>>();
foreach (var (i, modelFile) in modelFiles.Enumerate())
{
var hashProgress = new Progress<ProgressReport>(report =>
{
Progress.IsIndeterminate = false;
Progress.Value = report.Percentage;
Progress.Text = modelFilesCount > 1 ?
$"Computing metadata for {modelFile.Info.Name} ({i}/{modelFilesCount})" :
$"Computing metadata for {report.Title}";
});
var hashBlake3 = await FileHash.GetBlake3Async(modelFile, hashProgress);
// Start a task to query the model in background
var queryTask = Task.Run(async () =>
{
var result = await modelFinder.LocalFindModel(hashBlake3);
result ??= await modelFinder.RemoteFindModel(hashBlake3);
if (result is null) return false; // Not found
var (model, version, file) = result.Value;
// Save connected model info json
var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Info.Name);
var modelInfo = new ConnectedModelInfo(
model, version, file, DateTimeOffset.UtcNow);
await modelInfo.SaveJsonToDirectory(DirectoryPath, modelFileName);
// If available, save thumbnail
var image = version.Images?.FirstOrDefault();
if (image != null)
{
var imageExt = Path.GetExtension(image.Url).TrimStart('.');
if (imageExt is "jpg" or "jpeg" or "png")
{
var imageDownloadPath = Path.GetFullPath(
Path.Combine(DirectoryPath, $"{modelFileName}.preview.{imageExt}"));
await downloadService.DownloadToFileAsync(image.Url, imageDownloadPath);
}
}
return true;
});
modelQueryTasks.Add(queryTask);
}
// Set progress to indeterminate
Progress.IsIndeterminate = true;
Progress.Text = "Checking connected model information";
// Wait for all model queries to finish
var modelQueryResults = await Task.WhenAll(modelQueryTasks);
var successCount = modelQueryResults.Count(r => r);
var totalCount = modelQueryResults.Length;
var failCount = totalCount - successCount;
await IndexAsync();
Progress.Value = 100;
Progress.Text = successCount switch
{
0 when failCount > 0 =>
"Import complete. No connected data found.",
> 0 when failCount > 0 =>
$"Import complete. Found connected data for {successCount} of {totalCount} models.",
_ => $"Import complete. Found connected data for all {totalCount} models."
};
DelayedClearProgress(TimeSpan.FromSeconds(1));
}
else
{
await IndexAsync();
Progress.Value = 100;
Progress.Text = "Import complete";
DelayedClearProgress(TimeSpan.FromSeconds(1));
}
}
/// <summary>
/// Clears progress after a delay.
/// </summary>
private void DelayedClearProgress(TimeSpan delay)
{
Task.Delay(delay).ContinueWith(_ =>
{
IsImportInProgress = false;
Progress.IsProgressVisible = false;
Progress.Value = 0;
Progress.Text = string.Empty;
});
}
/// <summary>
/// Gets checkpoint files from folder index
/// </summary>
private async Task<List<CheckpointFile>> GetCheckpointFilesAsync(IProgress<ProgressReport>? progress = default)
{
if (!Directory.Exists(DirectoryPath))
{
return new List<CheckpointFile>();
}
return await (progress switch
{
null => Task.Run(() =>
CheckpointFile.FromDirectoryIndex(dialogFactory, DirectoryPath).ToList()),
_ => Task.Run(() =>
CheckpointFile.FromDirectoryIndex(dialogFactory, DirectoryPath, progress).ToList())
});
}
/// <summary>
/// Indexes the folder for checkpoint files and refreshes the CheckPointFiles collection.
/// </summary>
public async Task IndexAsync(IProgress<ProgressReport>? progress = default)
{
SubFolders.Clear();
foreach (var folder in Directory.GetDirectories(DirectoryPath))
{
// Inherit our folder type
var subFolder = new CheckpointFolder(dialogFactory, settingsManager,
downloadService, modelFinder,
useCategoryVisibility: false)
{
Title = Path.GetFileName(folder),
DirectoryPath = folder,
FolderType = FolderType
};
await subFolder.IndexAsync(progress);
SubFolders.Add(subFolder);
}
var checkpointFiles = await GetCheckpointFilesAsync();
CheckpointFiles.Clear();
foreach (var checkpointFile in checkpointFiles)
{
CheckpointFiles.Add(checkpointFile);
}
}
}

4
StabilityMatrix/StabilityMatrix.csproj

@ -111,8 +111,4 @@
<ProjectReference Include="..\StabilityMatrix.Core\StabilityMatrix.Core.csproj" />
</ItemGroup>
<ItemGroup>
<Folder Include="Models\" />
</ItemGroup>
</Project>

1
StabilityMatrix/ViewModels/CheckpointManagerViewModel.cs

@ -10,6 +10,7 @@ using NLog;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Services;
using StabilityMatrix.Helper;
using StabilityMatrix.Models;
namespace StabilityMatrix.ViewModels;

1
StabilityMatrix/ViewModels/TextToImageViewModel.cs

@ -14,6 +14,7 @@ using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Services;
using StabilityMatrix.Helper;
using StabilityMatrix.Models;
using StabilityMatrix.Services;
namespace StabilityMatrix.ViewModels;

Loading…
Cancel
Save