JT
1 year ago
10 changed files with 598 additions and 16 deletions
@ -0,0 +1,24 @@
|
||||
using CommunityToolkit.Mvvm.ComponentModel; |
||||
|
||||
namespace StabilityMatrix.Avalonia.ViewModels; |
||||
|
||||
/// <summary> |
||||
/// Generic view model for progress reporting. |
||||
/// </summary> |
||||
public partial class ProgressViewModel : ObservableObject |
||||
{ |
||||
[ObservableProperty] |
||||
[NotifyPropertyChangedFor(nameof(IsTextVisible))] |
||||
private string text; |
||||
|
||||
[ObservableProperty] |
||||
private double value; |
||||
|
||||
[ObservableProperty] |
||||
private bool isIndeterminate; |
||||
|
||||
[ObservableProperty] |
||||
private bool isProgressVisible; |
||||
|
||||
public virtual bool IsTextVisible => string.IsNullOrWhiteSpace(Text); |
||||
} |
@ -0,0 +1,226 @@
|
||||
using System; |
||||
using System.Collections.Generic; |
||||
using System.Collections.ObjectModel; |
||||
using System.Diagnostics; |
||||
using System.IO; |
||||
using System.Linq; |
||||
using System.Threading.Tasks; |
||||
using System.Windows; |
||||
using System.Windows.Media.Imaging; |
||||
using AsyncAwaitBestPractices; |
||||
using CommunityToolkit.Mvvm.ComponentModel; |
||||
using CommunityToolkit.Mvvm.Input; |
||||
using NLog; |
||||
using StabilityMatrix.Core.Extensions; |
||||
using StabilityMatrix.Core.Helper; |
||||
using StabilityMatrix.Core.Models; |
||||
using StabilityMatrix.Core.Models.Progress; |
||||
using StabilityMatrix.Core.Processes; |
||||
using StabilityMatrix.Helper; |
||||
|
||||
namespace StabilityMatrix.Models; |
||||
|
||||
public partial class CheckpointFile : ObservableObject |
||||
{ |
||||
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); |
||||
private readonly IDialogFactory dialogFactory; |
||||
|
||||
// Event for when this file is deleted |
||||
public event EventHandler<CheckpointFile>? Deleted; |
||||
|
||||
/// <summary> |
||||
/// Absolute path to the checkpoint file. |
||||
/// </summary> |
||||
[ObservableProperty] |
||||
[NotifyPropertyChangedFor(nameof(FileName))] |
||||
private string filePath = string.Empty; |
||||
|
||||
/// <summary> |
||||
/// Custom title for UI. |
||||
/// </summary> |
||||
[ObservableProperty] |
||||
private string title = string.Empty; |
||||
|
||||
public string? PreviewImagePath { get; set; } |
||||
public BitmapImage? PreviewImage { get; set; } |
||||
public bool IsPreviewImageLoaded => PreviewImage != null; |
||||
|
||||
[ObservableProperty] |
||||
private ConnectedModelInfo? connectedModel; |
||||
public bool IsConnectedModel => ConnectedModel != null; |
||||
|
||||
[ObservableProperty] private bool isLoading; |
||||
|
||||
public string FileName => Path.GetFileName(FilePath); |
||||
|
||||
public ObservableCollection<string> Badges { get; set; } = new(); |
||||
|
||||
private static readonly string[] SupportedCheckpointExtensions = { ".safetensors", ".pt", ".ckpt", ".pth", "bin" }; |
||||
private static readonly string[] SupportedImageExtensions = { ".png", ".jpg", ".jpeg" }; |
||||
private static readonly string[] SupportedMetadataExtensions = { ".json" }; |
||||
|
||||
public CheckpointFile(IDialogFactory dialogFactory) |
||||
{ |
||||
this.dialogFactory = dialogFactory; |
||||
} |
||||
|
||||
partial void OnConnectedModelChanged(ConnectedModelInfo? value) |
||||
{ |
||||
// Update title, first check user defined, then connected model name |
||||
Title = value?.UserTitle ?? value?.ModelName ?? string.Empty; |
||||
// Update badges |
||||
Badges.Clear(); |
||||
var fpType = value.FileMetadata.Fp?.GetStringValue().ToUpperInvariant(); |
||||
if (fpType != null) |
||||
{ |
||||
Badges.Add(fpType); |
||||
} |
||||
if (!string.IsNullOrWhiteSpace(value.BaseModel)) |
||||
{ |
||||
Badges.Add(value.BaseModel); |
||||
} |
||||
} |
||||
|
||||
[RelayCommand] |
||||
private async Task DeleteAsync() |
||||
{ |
||||
if (File.Exists(FilePath)) |
||||
{ |
||||
IsLoading = true; |
||||
try |
||||
{ |
||||
await using var delay = new MinimumDelay(200, 500); |
||||
await Task.Run(() => File.Delete(FilePath)); |
||||
if (PreviewImagePath != null && File.Exists(PreviewImagePath)) |
||||
{ |
||||
await Task.Run(() => File.Delete(PreviewImagePath)); |
||||
} |
||||
} |
||||
catch (IOException ex) |
||||
{ |
||||
Logger.Warn($"Failed to delete checkpoint file {FilePath}: {ex.Message}"); |
||||
return; // Don't delete from collection |
||||
} |
||||
finally |
||||
{ |
||||
IsLoading = false; |
||||
} |
||||
} |
||||
Deleted?.Invoke(this, this); |
||||
} |
||||
|
||||
[RelayCommand] |
||||
private async Task RenameAsync() |
||||
{ |
||||
var responses = await dialogFactory.ShowTextEntryDialog("Rename Model", new [] |
||||
{ |
||||
("File Name", FileName) |
||||
}); |
||||
var name = responses?.FirstOrDefault(); |
||||
if (name == null) return; |
||||
|
||||
// Rename file in OS |
||||
try |
||||
{ |
||||
var newFilePath = Path.Combine(Path.GetDirectoryName(FilePath) ?? "", name); |
||||
File.Move(FilePath, newFilePath); |
||||
FilePath = newFilePath; |
||||
} |
||||
catch (Exception e) |
||||
{ |
||||
Console.WriteLine(e); |
||||
throw; |
||||
} |
||||
} |
||||
|
||||
[RelayCommand] |
||||
private void OpenOnCivitAi() |
||||
{ |
||||
ProcessRunner.OpenUrl($"https://civitai.com/models/{ConnectedModel.ModelId}"); |
||||
} |
||||
|
||||
// Loads image from path |
||||
private async Task LoadPreviewImage() |
||||
{ |
||||
if (PreviewImagePath == null) return; |
||||
var bytes = await File.ReadAllBytesAsync(PreviewImagePath); |
||||
await Application.Current.Dispatcher.InvokeAsync(() => |
||||
{ |
||||
var bitmap = new BitmapImage(); |
||||
using var ms = new MemoryStream(bytes); |
||||
bitmap.BeginInit(); |
||||
bitmap.StreamSource = ms; |
||||
bitmap.CacheOption = BitmapCacheOption.OnLoad; |
||||
bitmap.EndInit(); |
||||
PreviewImage = bitmap; |
||||
}); |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Indexes directory and yields all checkpoint files. |
||||
/// First we match all files with supported extensions. |
||||
/// If found, we also look for |
||||
/// - {filename}.preview.{image-extensions} (preview image) |
||||
/// - {filename}.cm-info.json (connected model info) |
||||
/// </summary> |
||||
public static IEnumerable<CheckpointFile> FromDirectoryIndex(IDialogFactory dialogFactory, string directory, SearchOption searchOption = SearchOption.TopDirectoryOnly) |
||||
{ |
||||
// Get all files with supported extensions |
||||
var allExtensions = SupportedCheckpointExtensions |
||||
.Concat(SupportedImageExtensions) |
||||
.Concat(SupportedMetadataExtensions); |
||||
|
||||
var files = allExtensions.AsParallel() |
||||
.SelectMany(pattern => Directory.EnumerateFiles(directory, $"*{pattern}", searchOption)).ToDictionary<string, string>(Path.GetFileName); |
||||
|
||||
foreach (var file in files.Keys.Where(k => SupportedCheckpointExtensions.Contains(Path.GetExtension(k)))) |
||||
{ |
||||
var checkpointFile = new CheckpointFile(dialogFactory) |
||||
{ |
||||
Title = Path.GetFileNameWithoutExtension(file), |
||||
FilePath = Path.Combine(directory, file), |
||||
}; |
||||
|
||||
// Check for connected model info |
||||
var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(file); |
||||
var cmInfoPath = $"{fileNameWithoutExtension}.cm-info.json"; |
||||
if (files.TryGetValue(cmInfoPath, out var jsonPath)) |
||||
{ |
||||
try |
||||
{ |
||||
var jsonData = File.ReadAllText(jsonPath); |
||||
checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(jsonData); |
||||
} |
||||
catch (IOException e) |
||||
{ |
||||
Debug.WriteLine($"Failed to parse {cmInfoPath}: {e}"); |
||||
} |
||||
} |
||||
|
||||
// Check for preview image |
||||
var previewImage = SupportedImageExtensions.Select(ext => $"{fileNameWithoutExtension}.preview{ext}").FirstOrDefault(files.ContainsKey); |
||||
if (previewImage != null) |
||||
{ |
||||
checkpointFile.PreviewImagePath = files[previewImage]; |
||||
checkpointFile.LoadPreviewImage().SafeFireAndForget(); |
||||
} |
||||
|
||||
yield return checkpointFile; |
||||
} |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Index with progress reporting. |
||||
/// </summary> |
||||
public static IEnumerable<CheckpointFile> FromDirectoryIndex(IDialogFactory dialogFactory, string directory, IProgress<ProgressReport> progress, |
||||
SearchOption searchOption = SearchOption.TopDirectoryOnly) |
||||
{ |
||||
var current = 0ul; |
||||
foreach (var checkpointFile in FromDirectoryIndex(dialogFactory, directory, searchOption)) |
||||
{ |
||||
current++; |
||||
progress.Report(new ProgressReport(current, "Indexing", checkpointFile.FileName)); |
||||
yield return checkpointFile; |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,332 @@
|
||||
using System; |
||||
using System.Collections.Generic; |
||||
using System.Collections.ObjectModel; |
||||
using System.Collections.Specialized; |
||||
using System.Diagnostics; |
||||
using System.IO; |
||||
using System.Linq; |
||||
using System.Threading.Tasks; |
||||
using System.Windows; |
||||
using CommunityToolkit.Mvvm.ComponentModel; |
||||
using CommunityToolkit.Mvvm.Input; |
||||
using StabilityMatrix.Core.Extensions; |
||||
using StabilityMatrix.Core.Helper; |
||||
using StabilityMatrix.Core.Models; |
||||
using StabilityMatrix.Core.Models.FileInterfaces; |
||||
using StabilityMatrix.Core.Models.Progress; |
||||
using StabilityMatrix.Core.Services; |
||||
using StabilityMatrix.Helper; |
||||
using StabilityMatrix.ViewModels; |
||||
|
||||
namespace StabilityMatrix.Models; |
||||
|
||||
public partial class CheckpointFolder : ObservableObject |
||||
{ |
||||
private readonly IDialogFactory dialogFactory; |
||||
private readonly ISettingsManager settingsManager; |
||||
private readonly IDownloadService downloadService; |
||||
private readonly ModelFinder modelFinder; |
||||
// ReSharper disable once FieldCanBeMadeReadOnly.Local |
||||
private bool useCategoryVisibility; |
||||
|
||||
/// <summary> |
||||
/// Absolute path to the folder. |
||||
/// </summary> |
||||
public string DirectoryPath { get; init; } = string.Empty; |
||||
|
||||
/// <summary> |
||||
/// Custom title for UI. |
||||
/// </summary> |
||||
[ObservableProperty] |
||||
[NotifyPropertyChangedFor(nameof(FolderType))] |
||||
[NotifyPropertyChangedFor(nameof(TitleWithFilesCount))] |
||||
private string title = string.Empty; |
||||
|
||||
[ObservableProperty] |
||||
private SharedFolderType folderType; |
||||
|
||||
/// <summary> |
||||
/// True if the category is enabled for the manager page. |
||||
/// </summary> |
||||
[ObservableProperty] |
||||
private bool isCategoryEnabled = true; |
||||
|
||||
[ObservableProperty] |
||||
[NotifyPropertyChangedFor(nameof(IsDragBlurEnabled))] |
||||
private bool isCurrentDragTarget; |
||||
|
||||
[ObservableProperty] |
||||
[NotifyPropertyChangedFor(nameof(IsDragBlurEnabled))] |
||||
private bool isImportInProgress; |
||||
|
||||
public bool IsDragBlurEnabled => IsCurrentDragTarget || IsImportInProgress; |
||||
public string TitleWithFilesCount => CheckpointFiles.Any() ? $"{Title} ({CheckpointFiles.Count})" : Title; |
||||
|
||||
public ProgressViewModel Progress { get; } = new(); |
||||
|
||||
public ObservableCollection<CheckpointFolder> SubFolders { get; init; } = new(); |
||||
public ObservableCollection<CheckpointFile> CheckpointFiles { get; init; } = new(); |
||||
|
||||
public RelayCommand OnPreviewDragEnterCommand => new(() => IsCurrentDragTarget = true); |
||||
public RelayCommand OnPreviewDragLeaveCommand => new(() => IsCurrentDragTarget = false); |
||||
|
||||
public CheckpointFolder( |
||||
IDialogFactory dialogFactory, |
||||
ISettingsManager settingsManager, |
||||
IDownloadService downloadService, |
||||
ModelFinder modelFinder, |
||||
bool useCategoryVisibility = true) |
||||
{ |
||||
this.dialogFactory = dialogFactory; |
||||
this.settingsManager = settingsManager; |
||||
this.downloadService = downloadService; |
||||
this.modelFinder = modelFinder; |
||||
this.useCategoryVisibility = useCategoryVisibility; |
||||
|
||||
CheckpointFiles.CollectionChanged += OnCheckpointFilesChanged; |
||||
} |
||||
|
||||
/// <summary> |
||||
/// When title is set, set the category enabled state from settings. |
||||
/// </summary> |
||||
// ReSharper disable once UnusedParameterInPartialMethod |
||||
partial void OnTitleChanged(string value) |
||||
{ |
||||
if (!useCategoryVisibility) return; |
||||
|
||||
// Update folder type |
||||
var result = Enum.TryParse(Title, out SharedFolderType type); |
||||
FolderType = result ? type : new SharedFolderType(); |
||||
|
||||
IsCategoryEnabled = settingsManager.IsSharedFolderCategoryVisible(FolderType); |
||||
} |
||||
|
||||
/// <summary> |
||||
/// When toggling the category enabled state, save it to settings. |
||||
/// </summary> |
||||
partial void OnIsCategoryEnabledChanged(bool value) |
||||
{ |
||||
if (!useCategoryVisibility) return; |
||||
if (value != settingsManager.IsSharedFolderCategoryVisible(FolderType)) |
||||
{ |
||||
settingsManager.SetSharedFolderCategoryVisible(FolderType, value); |
||||
} |
||||
} |
||||
|
||||
// On collection changes |
||||
private void OnCheckpointFilesChanged(object? sender, NotifyCollectionChangedEventArgs e) |
||||
{ |
||||
OnPropertyChanged(nameof(TitleWithFilesCount)); |
||||
if (e.NewItems == null) return; |
||||
// On new added items, add event handler for deletion |
||||
foreach (CheckpointFile item in e.NewItems) |
||||
{ |
||||
item.Deleted += OnCheckpointFileDelete; |
||||
} |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Handler for CheckpointFile requesting to be deleted from the collection. |
||||
/// </summary> |
||||
/// <param name="sender"></param> |
||||
/// <param name="file"></param> |
||||
private void OnCheckpointFileDelete(object? sender, CheckpointFile file) |
||||
{ |
||||
Application.Current.Dispatcher.Invoke(() => CheckpointFiles.Remove(file)); |
||||
} |
||||
|
||||
[RelayCommand] |
||||
private async Task OnPreviewDropAsync(DragEventArgs e) |
||||
{ |
||||
IsImportInProgress = true; |
||||
IsCurrentDragTarget = false; |
||||
|
||||
if (e.Data.GetData(DataFormats.FileDrop) is not string[] files || files.Length < 1) |
||||
{ |
||||
IsImportInProgress = false; |
||||
return; |
||||
} |
||||
|
||||
await ImportFilesAsync(files, settingsManager.Settings.IsImportAsConnected); |
||||
} |
||||
|
||||
[RelayCommand] |
||||
private void ShowInExplorer(string path) |
||||
{ |
||||
Process.Start("explorer.exe", path); |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Imports files to the folder. Reports progress to instance properties. |
||||
/// </summary> |
||||
public async Task ImportFilesAsync(IEnumerable<string> files, bool convertToConnected = false) |
||||
{ |
||||
Progress.IsIndeterminate = true; |
||||
Progress.IsProgressVisible = true; |
||||
var copyPaths = files.ToDictionary(k => k, v => Path.Combine(DirectoryPath, Path.GetFileName(v))); |
||||
|
||||
var progress = new Progress<ProgressReport>(report => |
||||
{ |
||||
Progress.IsIndeterminate = false; |
||||
Progress.Value = report.Percentage; |
||||
// For multiple files, add count |
||||
Progress.Text = copyPaths.Count > 1 ? $"Importing {report.Title} ({report.Message})" : $"Importing {report.Title}"; |
||||
}); |
||||
|
||||
await FileTransfers.CopyFiles(copyPaths, progress); |
||||
|
||||
// Hash files and convert them to connected model if found |
||||
if (convertToConnected) |
||||
{ |
||||
var modelFilesCount = copyPaths.Count; |
||||
var modelFiles = copyPaths.Values |
||||
.Select(path => new FilePath(path)); |
||||
|
||||
// Holds tasks for model queries after hash |
||||
var modelQueryTasks = new List<Task<bool>>(); |
||||
|
||||
foreach (var (i, modelFile) in modelFiles.Enumerate()) |
||||
{ |
||||
var hashProgress = new Progress<ProgressReport>(report => |
||||
{ |
||||
Progress.IsIndeterminate = false; |
||||
Progress.Value = report.Percentage; |
||||
Progress.Text = modelFilesCount > 1 ? |
||||
$"Computing metadata for {modelFile.Info.Name} ({i}/{modelFilesCount})" : |
||||
$"Computing metadata for {report.Title}"; |
||||
}); |
||||
|
||||
var hashBlake3 = await FileHash.GetBlake3Async(modelFile, hashProgress); |
||||
|
||||
// Start a task to query the model in background |
||||
var queryTask = Task.Run(async () => |
||||
{ |
||||
var result = await modelFinder.LocalFindModel(hashBlake3); |
||||
result ??= await modelFinder.RemoteFindModel(hashBlake3); |
||||
|
||||
if (result is null) return false; // Not found |
||||
|
||||
var (model, version, file) = result.Value; |
||||
|
||||
// Save connected model info json |
||||
var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Info.Name); |
||||
var modelInfo = new ConnectedModelInfo( |
||||
model, version, file, DateTimeOffset.UtcNow); |
||||
await modelInfo.SaveJsonToDirectory(DirectoryPath, modelFileName); |
||||
|
||||
// If available, save thumbnail |
||||
var image = version.Images?.FirstOrDefault(); |
||||
if (image != null) |
||||
{ |
||||
var imageExt = Path.GetExtension(image.Url).TrimStart('.'); |
||||
if (imageExt is "jpg" or "jpeg" or "png") |
||||
{ |
||||
var imageDownloadPath = Path.GetFullPath( |
||||
Path.Combine(DirectoryPath, $"{modelFileName}.preview.{imageExt}")); |
||||
await downloadService.DownloadToFileAsync(image.Url, imageDownloadPath); |
||||
} |
||||
} |
||||
|
||||
return true; |
||||
}); |
||||
modelQueryTasks.Add(queryTask); |
||||
} |
||||
|
||||
// Set progress to indeterminate |
||||
Progress.IsIndeterminate = true; |
||||
Progress.Text = "Checking connected model information"; |
||||
|
||||
// Wait for all model queries to finish |
||||
var modelQueryResults = await Task.WhenAll(modelQueryTasks); |
||||
|
||||
var successCount = modelQueryResults.Count(r => r); |
||||
var totalCount = modelQueryResults.Length; |
||||
var failCount = totalCount - successCount; |
||||
|
||||
await IndexAsync(); |
||||
|
||||
Progress.Value = 100; |
||||
Progress.Text = successCount switch |
||||
{ |
||||
0 when failCount > 0 => |
||||
"Import complete. No connected data found.", |
||||
> 0 when failCount > 0 => |
||||
$"Import complete. Found connected data for {successCount} of {totalCount} models.", |
||||
_ => $"Import complete. Found connected data for all {totalCount} models." |
||||
}; |
||||
|
||||
DelayedClearProgress(TimeSpan.FromSeconds(1)); |
||||
} |
||||
else |
||||
{ |
||||
await IndexAsync(); |
||||
Progress.Value = 100; |
||||
Progress.Text = "Import complete"; |
||||
DelayedClearProgress(TimeSpan.FromSeconds(1)); |
||||
} |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Clears progress after a delay. |
||||
/// </summary> |
||||
private void DelayedClearProgress(TimeSpan delay) |
||||
{ |
||||
Task.Delay(delay).ContinueWith(_ => |
||||
{ |
||||
IsImportInProgress = false; |
||||
Progress.IsProgressVisible = false; |
||||
Progress.Value = 0; |
||||
Progress.Text = string.Empty; |
||||
}); |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Gets checkpoint files from folder index |
||||
/// </summary> |
||||
private async Task<List<CheckpointFile>> GetCheckpointFilesAsync(IProgress<ProgressReport>? progress = default) |
||||
{ |
||||
if (!Directory.Exists(DirectoryPath)) |
||||
{ |
||||
return new List<CheckpointFile>(); |
||||
} |
||||
|
||||
return await (progress switch |
||||
{ |
||||
null => Task.Run(() => |
||||
CheckpointFile.FromDirectoryIndex(dialogFactory, DirectoryPath).ToList()), |
||||
|
||||
_ => Task.Run(() => |
||||
CheckpointFile.FromDirectoryIndex(dialogFactory, DirectoryPath, progress).ToList()) |
||||
}); |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Indexes the folder for checkpoint files and refreshes the CheckPointFiles collection. |
||||
/// </summary> |
||||
public async Task IndexAsync(IProgress<ProgressReport>? progress = default) |
||||
{ |
||||
SubFolders.Clear(); |
||||
foreach (var folder in Directory.GetDirectories(DirectoryPath)) |
||||
{ |
||||
// Inherit our folder type |
||||
var subFolder = new CheckpointFolder(dialogFactory, settingsManager, |
||||
downloadService, modelFinder, |
||||
useCategoryVisibility: false) |
||||
{ |
||||
Title = Path.GetFileName(folder), |
||||
DirectoryPath = folder, |
||||
FolderType = FolderType |
||||
}; |
||||
|
||||
await subFolder.IndexAsync(progress); |
||||
SubFolders.Add(subFolder); |
||||
} |
||||
|
||||
var checkpointFiles = await GetCheckpointFilesAsync(); |
||||
CheckpointFiles.Clear(); |
||||
foreach (var checkpointFile in checkpointFiles) |
||||
{ |
||||
CheckpointFiles.Add(checkpointFile); |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue