You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
230 lines
7.8 KiB
230 lines
7.8 KiB
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Data;
using Avalonia.Media.Imaging;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;

namespace StabilityMatrix.Avalonia.ViewModels;

/// <summary>
/// View model for a single checkpoint file on disk, together with its optional
/// preview image and connected model info sidecar files.
/// </summary>
public partial class CheckpointFile : ViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Raised by <see cref="DeleteAsync"/> once the checkpoint file no longer exists on disk
    /// (deleted successfully, or already missing). The argument is this instance.
    /// </summary>
    public event EventHandler<CheckpointFile>? Deleted;

    /// <summary>
    /// Absolute path to the checkpoint file.
    /// </summary>
    [ObservableProperty, NotifyPropertyChangedFor(nameof(FileName))]
    private string filePath = string.Empty;

    /// <summary>
    /// Custom title for UI.
    /// </summary>
    [ObservableProperty]
    private string title = string.Empty;

    /// <summary>
    /// Path to the preview image sidecar, if one was found.
    /// </summary>
    public string? PreviewImagePath { get; set; }

    /// <summary>
    /// Decoded preview image. It is assigned asynchronously after construction, so it must raise
    /// change notifications for bindings that were created before loading finished.
    /// </summary>
    [ObservableProperty, NotifyPropertyChangedFor(nameof(IsPreviewImageLoaded))]
    private Bitmap? previewImage;

    public bool IsPreviewImageLoaded => PreviewImage is not null;

    /// <summary>
    /// Connected model info, read from the <c>{filename}.cm-info.json</c> sidecar when present.
    /// </summary>
    [ObservableProperty, NotifyPropertyChangedFor(nameof(IsConnectedModel))]
    private ConnectedModelInfo? connectedModel;

    public bool IsConnectedModel => ConnectedModel is not null;

    [ObservableProperty] private bool isLoading;

    public string FileName => Path.GetFileName(FilePath);

    public ObservableCollection<string> Badges { get; set; } = new();

    // Extensions include the leading dot so they compare equal to Path.GetExtension results.
    private static readonly string[] SupportedCheckpointExtensions = { ".safetensors", ".pt", ".ckpt", ".pth", ".bin" };
    private static readonly string[] SupportedImageExtensions = { ".png", ".jpg", ".jpeg" };
    private static readonly string[] SupportedMetadataExtensions = { ".json" };

    /// <summary>
    /// Refreshes <see cref="Title"/> and <see cref="Badges"/> from the newly assigned connected model info.
    /// </summary>
    partial void OnConnectedModelChanged(ConnectedModelInfo? value)
    {
        // Update title, first check user defined, then connected model name
        Title = value?.UserTitle ?? value?.ModelName ?? string.Empty;

        // Update badges
        Badges.Clear();
        var fpType = value?.FileMetadata.Fp?.GetStringValue().ToUpperInvariant();
        if (fpType != null)
        {
            Badges.Add(fpType);
        }
        if (!string.IsNullOrWhiteSpace(value?.BaseModel))
        {
            Badges.Add(value.BaseModel);
        }
    }

    /// <summary>
    /// Deletes the checkpoint file (and its preview image, best-effort), then raises <see cref="Deleted"/>.
    /// If the checkpoint itself cannot be deleted, the failure is logged and <see cref="Deleted"/> is not raised.
    /// </summary>
    [RelayCommand]
    private async Task DeleteAsync()
    {
        if (File.Exists(FilePath))
        {
            IsLoading = true;
            try
            {
                // NOTE(review): MinimumDelay is a project helper; presumably it keeps the loading
                // state visible for a minimum duration when disposed — confirm.
                await using var delay = new MinimumDelay(200, 500);
                await Task.Run(() => File.Delete(FilePath));

                if (PreviewImagePath is { } previewPath && File.Exists(previewPath))
                {
                    try
                    {
                        await Task.Run(() => File.Delete(previewPath));
                    }
                    catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
                    {
                        // The checkpoint is already gone, so still report deletion below;
                        // otherwise the UI would keep an entry for a file that no longer exists.
                        Logger.Warn(ex, $"Failed to delete preview image {previewPath}");
                    }
                }
            }
            catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
            {
                Logger.Warn(ex, $"Failed to delete checkpoint file {FilePath}");
                return; // Don't delete from collection
            }
            finally
            {
                IsLoading = false;
            }
        }

        Deleted?.Invoke(this, this);
    }

    /// <summary>
    /// Prompts for a new file name and renames the checkpoint file within its current folder.
    /// Failures are logged and leave <see cref="FilePath"/> unchanged.
    /// </summary>
    [RelayCommand]
    private async Task RenameAsync()
    {
        // Parent folder path
        var parentPath = Path.GetDirectoryName(FilePath) ?? "";

        var textFields = new TextBoxField[]
        {
            new()
            {
                Label = "File name",
                Validator = text =>
                {
                    if (string.IsNullOrWhiteSpace(text)) throw new
                        DataValidationException("File name is required");

                    // Reject directory separators and other invalid characters so the file
                    // cannot be moved out of its folder by typing a relative path.
                    if (text.IndexOfAny(Path.GetInvalidFileNameChars()) >= 0) throw new
                        DataValidationException("File name contains invalid characters");

                    if (File.Exists(Path.Combine(parentPath, text))) throw new
                        DataValidationException("File name already exists");
                }
            }
        };

        var dialog = DialogHelper.CreateTextEntryDialog("Rename Model", "", textFields);

        if (await dialog.ShowAsync() != ContentDialogResult.Primary) return;

        var name = textFields[0].Text;
        // Rename file in OS
        try
        {
            var newFilePath = Path.Combine(parentPath, name);
            File.Move(FilePath, newFilePath);
            FilePath = newFilePath;

            // Without connected model info the title mirrors the file name (as set by
            // FromDirectoryIndex), so keep it in sync with the new name.
            if (ConnectedModel is null)
            {
                Title = Path.GetFileNameWithoutExtension(newFilePath);
            }
        }
        catch (Exception e)
        {
            Logger.Warn(e, $"Failed to rename checkpoint file {FilePath}");
        }
    }

    /// <summary>
    /// Opens the connected model's CivitAI page; does nothing when no model info is connected.
    /// </summary>
    [RelayCommand]
    private void OpenOnCivitAi()
    {
        if (ConnectedModel is null) return;
        ProcessRunner.OpenUrl($"https://civitai.com/models/{ConnectedModel.ModelId}");
    }

    /// <summary>
    /// Loads <see cref="PreviewImage"/> from <see cref="PreviewImagePath"/> on the UI thread.
    /// Callers fire-and-forget this task, so failures (missing/locked/corrupt file) are logged
    /// here instead of being allowed to escape.
    /// </summary>
    private async Task LoadPreviewImage()
    {
        if (PreviewImagePath is not { } path) return;

        try
        {
            await Dispatcher.UIThread.InvokeAsync(() =>
            {
                // Dispose the stream so the image file is not left open (and locked) after decoding.
                using var stream = File.OpenRead(path);
                PreviewImage = new Bitmap(stream);
            });
        }
        catch (Exception e)
        {
            Logger.Warn(e, $"Failed to load preview image {path}");
        }
    }

    /// <summary>
    /// Indexes directory and yields all checkpoint files.
    /// First we match all files with supported extensions.
    /// If found, we also look for these sidecars in the same folder as the checkpoint:
    /// - {filename}.preview.{image-extensions} (preview image)
    /// - {filename}.cm-info.json (connected model info)
    /// </summary>
    /// <param name="directory">Directory to scan.</param>
    /// <param name="searchOption">Whether to include subdirectories.</param>
    public static IEnumerable<CheckpointFile> FromDirectoryIndex(string directory, SearchOption searchOption = SearchOption.TopDirectoryOnly)
    {
        // Get all files with supported extensions (single directory walk, filtered by extension)
        var indexedExtensions = SupportedCheckpointExtensions
            .Concat(SupportedImageExtensions)
            .Concat(SupportedMetadataExtensions)
            .ToHashSet();

        // Keep full paths rather than bare file names: with AllDirectories the same file name can
        // occur in several folders, and each checkpoint must use its own folder's sidecars.
        var files = Directory.EnumerateFiles(directory, "*", searchOption)
            .Where(path => indexedExtensions.Contains(Path.GetExtension(path)))
            .ToList();
        var fileSet = files.ToHashSet();

        foreach (var checkpointPath in files.Where(path => SupportedCheckpointExtensions.Contains(Path.GetExtension(path))))
        {
            var checkpointFile = new CheckpointFile
            {
                Title = Path.GetFileNameWithoutExtension(checkpointPath),
                FilePath = checkpointPath,
            };

            // Sidecar paths are the checkpoint path minus its extension plus a suffix. They are built
            // by string operations so they match the enumerated paths exactly (no separator normalization).
            var basePath = Path.ChangeExtension(checkpointPath, null);

            // Check for connected model info
            var cmInfoPath = $"{basePath}.cm-info.json";
            if (fileSet.Contains(cmInfoPath))
            {
                try
                {
                    var jsonData = File.ReadAllText(cmInfoPath);
                    checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(jsonData);
                }
                catch (Exception e) when (e is IOException or UnauthorizedAccessException or JsonException)
                {
                    // A single unreadable/malformed sidecar must not abort indexing of the whole directory.
                    Logger.Warn(e, $"Failed to parse {cmInfoPath}");
                }
            }

            // Check for preview image
            var previewImagePath = SupportedImageExtensions
                .Select(ext => $"{basePath}.preview{ext}")
                .FirstOrDefault(fileSet.Contains);
            if (previewImagePath != null)
            {
                checkpointFile.PreviewImagePath = previewImagePath;
                checkpointFile.LoadPreviewImage().SafeFireAndForget();
            }

            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Index with progress reporting: yields the same items as
    /// <see cref="FromDirectoryIndex(string, SearchOption)"/> and, before each one, reports a
    /// <see cref="ProgressReport"/> built from a running count, "Indexing" and the file name.
    /// </summary>
    public static IEnumerable<CheckpointFile> FromDirectoryIndex(string directory, IProgress<ProgressReport> progress,
        SearchOption searchOption = SearchOption.TopDirectoryOnly)
    {
        var current = 0ul;
        foreach (var checkpointFile in FromDirectoryIndex(directory, searchOption))
        {
            current++;
            progress.Report(new ProgressReport(current, "Indexing", checkpointFile.FileName));
            yield return checkpointFile;
        }
    }
}
|
|