using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Data;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;

namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager;

/// <summary>
/// View model for a single model file on disk, plus its optional sidecar files:
/// a "{name}.preview{ext}" image and a "{name}.cm-info.json" connected model info file.
/// </summary>
[ManagedService]
[Transient]
public partial class CheckpointFile : ViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Absolute path to the checkpoint file.
    /// </summary>
    [ObservableProperty, NotifyPropertyChangedFor(nameof(FileName))]
    private string filePath = string.Empty;

    /// <summary>
    /// Custom title for UI.
    /// </summary>
    [ObservableProperty]
    private string title = string.Empty;

    /// <summary>
    /// Path to preview image. Can be local or remote URL.
    /// </summary>
    [ObservableProperty]
    private string? previewImagePath;

    // CanShowTriggerWords is derived from ConnectedModel, so it must be re-notified too.
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(IsConnectedModel))]
    [NotifyPropertyChangedFor(nameof(CanShowTriggerWords))]
    private ConnectedModelInfo? connectedModel;

    public bool IsConnectedModel => ConnectedModel != null;

    [ObservableProperty]
    private bool isLoading;

    [ObservableProperty]
    private CivitModelType modelType;

    [ObservableProperty]
    private CheckpointFolder parentFolder;

    [ObservableProperty]
    private ProgressReport? progress;

    public string FileName => Path.GetFileName(FilePath);

    public bool CanShowTriggerWords =>
        ConnectedModel != null && !string.IsNullOrWhiteSpace(ConnectedModel.TrainedWordsString);

    public ObservableCollection<string> Badges { get; set; } = new();

    public static readonly string[] SupportedCheckpointExtensions =
    {
        ".safetensors",
        ".pt",
        ".ckpt",
        ".pth",
        ".bin"
    };

    private static readonly string[] SupportedImageExtensions = { ".png", ".jpg", ".jpeg" };

    private static readonly string[] SupportedMetadataExtensions = { ".json" };

    partial void OnConnectedModelChanged(ConnectedModelInfo? value)
    {
        // Update title, first check user defined, then connected model name
        Title = value?.UserTitle ?? value?.ModelName ?? string.Empty;

        // Update badges
        Badges.Clear();

        var fpType = value?.FileMetadata.Fp?.GetStringValue().ToUpperInvariant();
        if (fpType != null)
        {
            Badges.Add(fpType);
        }

        if (!string.IsNullOrWhiteSpace(value?.BaseModel))
        {
            Badges.Add(value.BaseModel);
        }
    }

    /// <summary>
    /// Path of the "{name}.cm-info.json" sidecar next to <see cref="FilePath"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="FilePath"/> is empty.</exception>
    private string GetConnectedModelInfoFilePath()
    {
        if (string.IsNullOrEmpty(FilePath))
        {
            throw new InvalidOperationException(
                "Cannot get connected model info file path when FilePath is empty"
            );
        }

        var modelNameNoExt = Path.GetFileNameWithoutExtension(FilePath);
        var modelDir = Path.GetDirectoryName(FilePath) ?? "";
        return Path.Combine(modelDir, $"{modelNameNoExt}.cm-info.json");
    }

    /// <summary>
    /// True if <paramref name="path"/> has one of <see cref="SupportedCheckpointExtensions"/>,
    /// compared case-insensitively.
    /// </summary>
    private static bool IsSupportedCheckpointFile(string path)
    {
        var extension = Path.GetExtension(path);
        return SupportedCheckpointExtensions.Any(
            ext => extension.Equals(ext, StringComparison.OrdinalIgnoreCase)
        );
    }

    /// <summary>
    /// Returns the first existing "{modelNameNoExt}.preview{ext}" file in
    /// <paramref name="directory"/>, trying <see cref="SupportedImageExtensions"/> in order,
    /// or null if none exists.
    /// </summary>
    private static string? FindPreviewImagePath(string directory, string modelNameNoExt)
    {
        return SupportedImageExtensions
            .Select(ext => Path.Combine(directory, $"{modelNameNoExt}.preview{ext}"))
            .FirstOrDefault(File.Exists);
    }

    /// <summary>
    /// Deletes the checkpoint file together with its preview image and connected model info,
    /// then removes this entry from the parent list.
    /// </summary>
    [RelayCommand]
    private async Task DeleteAsync()
    {
        if (File.Exists(FilePath))
        {
            IsLoading = true;
            try
            {
                await using var delay = new MinimumDelay(200, 500);

                var checkpointPath = FilePath;
                await Task.Run(() => File.Delete(checkpointPath));

                var previewPath = PreviewImagePath;
                if (previewPath != null && File.Exists(previewPath))
                {
                    await Task.Run(() => File.Delete(previewPath));
                }

                if (ConnectedModel != null)
                {
                    var cmInfoPath = GetConnectedModelInfoFilePath();
                    if (File.Exists(cmInfoPath))
                    {
                        await Task.Run(() => File.Delete(cmInfoPath));
                    }
                }
            }
            catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
            {
                Logger.Warn(ex, $"Failed to delete checkpoint file {FilePath}");

                // Keep the entry only while the checkpoint itself is still on disk. A failure
                // while removing a sidecar must not leave a stale entry for a deleted model.
                if (File.Exists(FilePath))
                {
                    return; // Don't delete from collection
                }
            }
            finally
            {
                IsLoading = false;
            }
        }

        RemoveFromParentList();
    }

    public void OnMoved() => RemoveFromParentList();

    /// <summary>
    /// Prompts for a new file name and renames the checkpoint, its preview image
    /// and its connected model info file within the same folder.
    /// </summary>
    [RelayCommand]
    private async Task RenameAsync()
    {
        // Parent folder path
        var parentPath = Path.GetDirectoryName(FilePath) ?? "";

        var textFields = new TextBoxField[]
        {
            new()
            {
                Label = "File name",
                Validator = text =>
                {
                    if (string.IsNullOrWhiteSpace(text))
                        throw new DataValidationException("File name is required");

                    // Reject separators and other invalid characters so the new name
                    // cannot move the file out of its folder through Path.Combine.
                    if (text.IndexOfAny(Path.GetInvalidFileNameChars()) >= 0)
                        throw new DataValidationException("File name contains invalid characters");

                    if (File.Exists(Path.Combine(parentPath, text)))
                        throw new DataValidationException("File name already exists");
                },
                Text = FileName
            }
        };

        var dialog = DialogHelper.CreateTextEntryDialog("Rename Model", "", textFields);

        if (await dialog.ShowAsync() != ContentDialogResult.Primary)
        {
            return;
        }

        var name = textFields[0].Text;
        var nameNoExt = Path.GetFileNameWithoutExtension(name);
        var originalNameNoExt = Path.GetFileNameWithoutExtension(FileName);
        var originalFilePath = FilePath;

        // Rename file in OS
        try
        {
            var newFilePath = Path.Combine(parentPath, name);
            File.Move(FilePath, newFilePath);
            FilePath = newFilePath;

            // If preview image exists, rename it too
            if (PreviewImagePath != null && File.Exists(PreviewImagePath))
            {
                var newPreviewImagePath = Path.Combine(
                    parentPath,
                    $"{nameNoExt}.preview{Path.GetExtension(PreviewImagePath)}"
                );
                File.Move(PreviewImagePath, newPreviewImagePath);
                PreviewImagePath = newPreviewImagePath;
            }

            // If connected model info exists, rename it too (.cm-info.json)
            if (ConnectedModel != null)
            {
                var cmInfoPath = Path.Combine(parentPath, $"{originalNameNoExt}.cm-info.json");
                if (File.Exists(cmInfoPath))
                {
                    File.Move(cmInfoPath, Path.Combine(parentPath, $"{nameNoExt}.cm-info.json"));
                }
            }
        }
        catch (Exception e)
        {
            // Best-effort: log and keep whatever state was reached.
            Logger.Warn(e, $"Failed to rename checkpoint file {originalFilePath}");
        }
    }

    [RelayCommand]
    private void OpenOnCivitAi()
    {
        if (ConnectedModel?.ModelId == null)
            return;
        ProcessRunner.OpenUrl($"https://civitai.com/models/{ConnectedModel.ModelId}");
    }

    [RelayCommand]
    private Task CopyTriggerWords()
    {
        if (ConnectedModel == null)
            return Task.CompletedTask;

        var words = ConnectedModel.TrainedWordsString;
        if (string.IsNullOrWhiteSpace(words))
            return Task.CompletedTask;

        return App.Clipboard.SetTextAsync(words);
    }

    /// <summary>
    /// Looks up connected model metadata for this file through <see cref="IMetadataImportService"/>.
    /// On success, updates <see cref="ConnectedModel"/> and re-resolves the preview image.
    /// </summary>
    /// <param name="forceReimport">Passed through to the import service.</param>
    [RelayCommand]
    private async Task FindConnectedMetadata(bool forceReimport = false)
    {
        if (
            App.Services.GetService(typeof(IMetadataImportService))
            is not IMetadataImportService importService
        )
            return;

        IsLoading = true;

        try
        {
            var progressReport = new Progress<ProgressReport>(report =>
            {
                Progress = report;
            });

            var cmInfo = await importService.GetMetadataForFile(
                FilePath,
                progressReport,
                forceReimport
            );

            if (cmInfo != null)
            {
                ConnectedModel = cmInfo;

                // Resolve relative to the file itself. ParentFolder is not set for entries
                // created by GetAllCheckpointFiles, and it may not be the file's direct folder.
                PreviewImagePath = FindPreviewImagePath(
                    Path.GetDirectoryName(FilePath) ?? "",
                    Path.GetFileNameWithoutExtension(FilePath)
                );
            }
        }
        finally
        {
            IsLoading = false;
        }
    }

    /// <summary>
    /// Indexes directory and yields all checkpoint files.
    /// First we match all files with supported extensions.
    /// If found, we also look for (in the file's own folder)
    /// - {filename}.preview{image-extension} (preview image)
    /// - {filename}.cm-info.json (connected model info)
    /// </summary>
    public static IEnumerable<CheckpointFile> FromDirectoryIndex(
        CheckpointFolder parentFolder,
        string directory,
        SearchOption searchOption = SearchOption.TopDirectoryOnly
    )
    {
        foreach (var file in Directory.EnumerateFiles(directory, "*.*", searchOption))
        {
            if (!IsSupportedCheckpointFile(file))
                continue;

            // EnumerateFiles already yields paths that include `directory`. With
            // SearchOption.AllDirectories a file may also sit in a subfolder, so sidecars
            // are looked up relative to the file's own folder.
            var fileDirectory = Path.GetDirectoryName(file) ?? directory;
            var nameNoExt = Path.GetFileNameWithoutExtension(file);

            var checkpointFile = new CheckpointFile { Title = nameNoExt, FilePath = file, };

            var jsonPath = checkpointFile.GetConnectedModelInfoFilePath();
            if (File.Exists(jsonPath))
            {
                var json = File.ReadAllText(jsonPath);
                var connectedModelInfo = ConnectedModelInfo.FromJson(json);
                checkpointFile.ConnectedModel = connectedModelInfo;
            }

            checkpointFile.PreviewImagePath =
                FindPreviewImagePath(fileDirectory, nameNoExt) ?? Assets.NoImage.ToString();

            checkpointFile.ParentFolder = parentFolder;
            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Recursively yields every supported checkpoint file under <paramref name="modelsDirectory"/>.
    /// Connected model info and preview images are resolved from each file's own folder.
    /// <see cref="ParentFolder"/> is not set on the returned items.
    /// </summary>
    public static IEnumerable<CheckpointFile> GetAllCheckpointFiles(string modelsDirectory)
    {
        foreach (
            var file in Directory.EnumerateFiles(
                modelsDirectory,
                "*.*",
                SearchOption.AllDirectories
            )
        )
        {
            if (!IsSupportedCheckpointFile(file))
                continue;

            var checkpointFile = new CheckpointFile
            {
                Title = Path.GetFileNameWithoutExtension(file),
                FilePath = file
            };

            var jsonPath = checkpointFile.GetConnectedModelInfoFilePath();
            if (File.Exists(jsonPath))
            {
                var json = File.ReadAllText(jsonPath);
                var connectedModelInfo = ConnectedModelInfo.FromJson(json);
                checkpointFile.ConnectedModel = connectedModelInfo;
                checkpointFile.ModelType = GetCivitModelType(file);
            }

            checkpointFile.PreviewImagePath = FindPreviewImagePath(
                Path.GetDirectoryName(file) ?? "",
                Path.GetFileNameWithoutExtension(file)
            );

            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Index with progress reporting.
    /// </summary>
    public static IEnumerable<CheckpointFile> FromDirectoryIndex(
        CheckpointFolder parentFolder,
        string directory,
        IProgress<ProgressReport> progress,
        SearchOption searchOption = SearchOption.TopDirectoryOnly
    )
    {
        var current = 0ul;
        foreach (var checkpointFile in FromDirectoryIndex(parentFolder, directory, searchOption))
        {
            current++;
            progress.Report(new ProgressReport(current, "Indexing", checkpointFile.FileName));
            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Infers the model type from the shared-folder name contained in the path.
    /// Checks run in order and the first match wins.
    /// </summary>
    private static CivitModelType GetCivitModelType(string filePath)
    {
        if (filePath.Contains(SharedFolderType.StableDiffusion.ToString()))
        {
            return CivitModelType.Checkpoint;
        }

        if (filePath.Contains(SharedFolderType.ControlNet.ToString()))
        {
            return CivitModelType.Controlnet;
        }

        if (filePath.Contains(SharedFolderType.Lora.ToString()))
        {
            return CivitModelType.LORA;
        }

        if (filePath.Contains(SharedFolderType.TextualInversion.ToString()))
        {
            return CivitModelType.TextualInversion;
        }

        if (filePath.Contains(SharedFolderType.Hypernetwork.ToString()))
        {
            return CivitModelType.Hypernetwork;
        }

        if (filePath.Contains(SharedFolderType.LyCORIS.ToString()))
        {
            return CivitModelType.LoCon;
        }

        return CivitModelType.Unknown;
    }

    /// <summary>
    /// Treats two files as equal when path, BLAKE3 hash, thumbnail URL and preview path all match.
    /// The hash code uses only the path, which is consistent with Equals.
    /// </summary>
    private sealed class FilePathEqualityComparer : IEqualityComparer<CheckpointFile>
    {
        public bool Equals(CheckpointFile? x, CheckpointFile? y)
        {
            if (ReferenceEquals(x, y))
                return true;
            if (ReferenceEquals(x, null))
                return false;
            if (ReferenceEquals(y, null))
                return false;
            if (x.GetType() != y.GetType())
                return false;

            return x.FilePath == y.FilePath
                && x.ConnectedModel?.Hashes.BLAKE3 == y.ConnectedModel?.Hashes.BLAKE3
                && x.ConnectedModel?.ThumbnailImageUrl == y.ConnectedModel?.ThumbnailImageUrl
                && x.PreviewImagePath == y.PreviewImagePath;
        }

        public int GetHashCode(CheckpointFile obj)
        {
            return obj.FilePath.GetHashCode();
        }
    }

    public static IEqualityComparer<CheckpointFile> FilePathComparer { get; } =
        new FilePathEqualityComparer();
}