Multi-Platform Package Manager for Stable Diffusion
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

437 lines
15 KiB

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Data;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
[ManagedService]
[Transient]
public partial class CheckpointFile : ViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Absolute path to the checkpoint file.
    /// </summary>
    [ObservableProperty, NotifyPropertyChangedFor(nameof(FileName))]
    private string filePath = string.Empty;

    /// <summary>
    /// Custom title for UI.
    /// </summary>
    [ObservableProperty]
    private string title = string.Empty;

    /// <summary>
    /// Path to preview image. Can be local or remote URL.
    /// </summary>
    [ObservableProperty]
    private string? previewImagePath;

    /// <summary>
    /// Metadata loaded from the <c>{name}.cm-info.json</c> sidecar file, if any.
    /// Both <c>IsConnectedModel</c> and <c>CanShowTriggerWords</c> are derived from this value,
    /// so both are re-notified when it changes.
    /// </summary>
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(IsConnectedModel), nameof(CanShowTriggerWords))]
    private ConnectedModelInfo? connectedModel;

    /// <summary>
    /// True when connected model metadata is present.
    /// </summary>
    public bool IsConnectedModel => ConnectedModel != null;

    [ObservableProperty]
    private bool isLoading;

    [ObservableProperty]
    private CivitModelType modelType;

    [ObservableProperty]
    private CheckpointFolder parentFolder;

    [ObservableProperty]
    private ProgressReport? progress;

    /// <summary>
    /// File name (including extension) of <c>FilePath</c>.
    /// </summary>
    public string FileName => Path.GetFileName(FilePath);

    /// <summary>
    /// True when connected metadata exists and its trained-words string is not blank.
    /// </summary>
    public bool CanShowTriggerWords =>
        ConnectedModel != null && !string.IsNullOrWhiteSpace(ConnectedModel.TrainedWordsString);

    public ObservableCollection<string> Badges { get; set; } = new();

    public static readonly string[] SupportedCheckpointExtensions = { ".safetensors", ".pt", ".ckpt", ".pth", ".bin" };
    private static readonly string[] SupportedImageExtensions = { ".png", ".jpg", ".jpeg", ".gif" };
    private static readonly string[] SupportedMetadataExtensions = { ".json" };

    /// <summary>
    /// Refreshes the title and badges whenever the connected model metadata changes.
    /// </summary>
    partial void OnConnectedModelChanged(ConnectedModelInfo? value)
    {
        // Update title, first check user defined, then connected model name
        Title = value?.UserTitle ?? value?.ModelName ?? string.Empty;

        // Rebuild badges from the file metadata's Fp value and the base model name, when present
        Badges.Clear();

        var fpType = value?.FileMetadata.Fp?.GetStringValue().ToUpperInvariant();
        if (fpType != null)
        {
            Badges.Add(fpType);
        }

        if (!string.IsNullOrWhiteSpace(value?.BaseModel))
        {
            Badges.Add(value.BaseModel);
        }
    }

    /// <summary>
    /// Gets the path of this file's connected model info sidecar (<c>{name}.cm-info.json</c>),
    /// located in the same directory as <c>FilePath</c>.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when <c>FilePath</c> is empty.</exception>
    private string GetConnectedModelInfoFilePath()
    {
        if (string.IsNullOrEmpty(FilePath))
        {
            throw new InvalidOperationException("Cannot get connected model info file path when FilePath is empty");
        }

        return GetConnectedModelInfoPath(FilePath);
    }

    /// <summary>
    /// Builds the <c>{name}.cm-info.json</c> sidecar path next to the given checkpoint file.
    /// </summary>
    private static string GetConnectedModelInfoPath(string checkpointPath)
    {
        var directory = Path.GetDirectoryName(checkpointPath) ?? "";
        return Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(checkpointPath)}.cm-info.json");
    }

    /// <summary>
    /// Returns the first existing <c>{name}.preview{ext}</c> image next to the given checkpoint file,
    /// trying <see cref="SupportedImageExtensions"/> in order; null if none exist.
    /// </summary>
    private static string? FindPreviewImagePath(string checkpointPath)
    {
        var directory = Path.GetDirectoryName(checkpointPath) ?? "";
        var nameNoExt = Path.GetFileNameWithoutExtension(checkpointPath);

        return SupportedImageExtensions
            .Select(ext => Path.Combine(directory, $"{nameNoExt}.preview{ext}"))
            .FirstOrDefault(File.Exists);
    }

    /// <summary>
    /// True when the path's extension is one of <see cref="SupportedCheckpointExtensions"/> (case-insensitive).
    /// </summary>
    private static bool IsSupportedCheckpointFile(string path)
    {
        // File extensions are identifiers, not linguistic text: compare ordinally
        var extension = Path.GetExtension(path);
        return SupportedCheckpointExtensions.Any(ext => extension.Equals(ext, StringComparison.OrdinalIgnoreCase));
    }

    /// <summary>
    /// Deletes the checkpoint file, its local preview image and (when connected) its cm-info sidecar,
    /// then removes this item from its parent list. If a delete fails with an I/O or access error,
    /// the failure is logged and the item is kept in the list.
    /// </summary>
    [RelayCommand]
    private async Task DeleteAsync()
    {
        if (File.Exists(FilePath))
        {
            IsLoading = true;
            try
            {
                // NOTE(review): MinimumDelay presumably pads the operation to a minimum duration
                // so the loading state is visible; it is awaited on dispose.
                await using var delay = new MinimumDelay(200, 500);

                // Snapshot paths so the background deletes are not affected by later property changes
                var checkpointPath = FilePath;
                await Task.Run(() => File.Delete(checkpointPath));

                var previewPath = PreviewImagePath;
                if (previewPath != null && File.Exists(previewPath))
                {
                    await Task.Run(() => File.Delete(previewPath));
                }

                if (ConnectedModel != null)
                {
                    var cmInfoPath = GetConnectedModelInfoFilePath();
                    if (File.Exists(cmInfoPath))
                    {
                        await Task.Run(() => File.Delete(cmInfoPath));
                    }
                }
            }
            catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
            {
                // Read-only / locked files raise UnauthorizedAccessException, not IOException
                Logger.Warn(ex, $"Failed to delete checkpoint file {FilePath}: {ex.Message}");
                return; // Don't delete from collection
            }
            finally
            {
                IsLoading = false;
            }
        }

        RemoveFromParentList();
    }

    /// <summary>
    /// Removes this item from its parent list (presumably invoked after the file was moved elsewhere).
    /// </summary>
    public void OnMoved() => RemoveFromParentList();

    /// <summary>
    /// Prompts for a new file name and renames the checkpoint file within its current directory,
    /// along with its local preview image and (when connected) its cm-info sidecar.
    /// Failures are logged, not rethrown.
    /// </summary>
    [RelayCommand]
    private async Task RenameAsync()
    {
        var originalFilePath = FilePath;

        // Parent folder path
        var parentPath = Path.GetDirectoryName(originalFilePath) ?? "";

        var textFields = new TextBoxField[]
        {
            new()
            {
                Label = "File name",
                Validator = text =>
                {
                    if (string.IsNullOrWhiteSpace(text))
                        throw new DataValidationException("File name is required");

                    // Reject directory separators and other invalid characters so a rename
                    // cannot move the file out of its current directory
                    if (text.IndexOfAny(Path.GetInvalidFileNameChars()) >= 0)
                        throw new DataValidationException("File name contains invalid characters");

                    if (File.Exists(Path.Combine(parentPath, text)))
                        throw new DataValidationException("File name already exists");
                },
                Text = FileName
            }
        };

        var dialog = DialogHelper.CreateTextEntryDialog("Rename Model", "", textFields);
        if (await dialog.ShowAsync() != ContentDialogResult.Primary)
            return;

        var name = textFields[0].Text;
        var nameNoExt = Path.GetFileNameWithoutExtension(name);

        // Rename file in OS
        try
        {
            var newFilePath = Path.Combine(parentPath, name);
            File.Move(originalFilePath, newFilePath);
            FilePath = newFilePath;

            // If preview image exists, rename it too
            var previewPath = PreviewImagePath;
            if (previewPath != null && File.Exists(previewPath))
            {
                var newPreviewImagePath = Path.Combine(
                    parentPath,
                    $"{nameNoExt}.preview{Path.GetExtension(previewPath)}"
                );
                File.Move(previewPath, newPreviewImagePath);
                PreviewImagePath = newPreviewImagePath;
            }

            // If connected model info exists, rename it too (<name>.cm-info.json).
            // The sidecar is located from the ORIGINAL path, since FilePath has already changed.
            if (ConnectedModel != null)
            {
                var originalCmInfoPath = GetConnectedModelInfoPath(originalFilePath);
                if (File.Exists(originalCmInfoPath))
                {
                    File.Move(originalCmInfoPath, Path.Combine(parentPath, $"{nameNoExt}.cm-info.json"));
                }
            }
        }
        catch (Exception e)
        {
            // Log the path the user asked to rename, not a possibly already-updated FilePath
            Logger.Warn(e, $"Failed to rename checkpoint file {originalFilePath}");
        }
    }

    /// <summary>
    /// Opens the connected model's CivitAI page; no-op when there is no connected model id.
    /// </summary>
    [RelayCommand]
    private void OpenOnCivitAi()
    {
        if (ConnectedModel?.ModelId == null)
            return;

        ProcessRunner.OpenUrl($"https://civitai.com/models/{ConnectedModel.ModelId}");
    }

    /// <summary>
    /// Copies the connected model's trained words to the clipboard; no-op when none are available.
    /// </summary>
    [RelayCommand]
    private Task CopyTriggerWords()
    {
        if (ConnectedModel == null)
            return Task.CompletedTask;

        var words = ConnectedModel.TrainedWordsString;
        if (string.IsNullOrWhiteSpace(words))
            return Task.CompletedTask;

        return App.Clipboard.SetTextAsync(words);
    }

    /// <summary>
    /// Asks the metadata import service for this file's metadata and, if found, updates
    /// <c>ConnectedModel</c> and re-resolves the preview image next to the model file.
    /// </summary>
    /// <param name="forceReimport">Passed through to the import service.</param>
    [RelayCommand]
    private async Task FindConnectedMetadata(bool forceReimport = false)
    {
        if (App.Services.GetService(typeof(IMetadataImportService)) is not IMetadataImportService importService)
            return;

        IsLoading = true;
        try
        {
            var progressReport = new Progress<ProgressReport>(report => Progress = report);

            var cmInfo = await importService.GetMetadataForFile(FilePath, progressReport, forceReimport);
            if (cmInfo != null)
            {
                ConnectedModel = cmInfo;

                // Resolve relative to the file itself: ParentFolder is not set for items from
                // GetAllCheckpointFiles, and is not the file's directory for recursively indexed files.
                PreviewImagePath = FindPreviewImagePath(FilePath);
            }
        }
        finally
        {
            IsLoading = false;
        }
    }

    /// <summary>
    /// Indexes directory and yields all checkpoint files.
    /// First we match all files with supported extensions.
    /// If found, we also look (in the same directory as the checkpoint file) for
    /// - {filename}.preview.{image-extensions} (preview image; falls back to Assets.NoImage)
    /// - {filename}.cm-info.json (connected model info)
    /// </summary>
    public static IEnumerable<CheckpointFile> FromDirectoryIndex(
        CheckpointFolder parentFolder,
        string directory,
        SearchOption searchOption = SearchOption.TopDirectoryOnly
    )
    {
        foreach (var file in Directory.EnumerateFiles(directory, "*.*", searchOption))
        {
            if (!IsSupportedCheckpointFile(file))
                continue;

            // EnumerateFiles already returns paths that include `directory`; combining them again
            // would double the prefix when `directory` is relative.
            var checkpointFile = new CheckpointFile
            {
                Title = Path.GetFileNameWithoutExtension(file),
                FilePath = file,
            };

            // Sidecars sit next to the model, which may be in a subdirectory when searching recursively
            var jsonPath = GetConnectedModelInfoPath(file);
            if (File.Exists(jsonPath))
            {
                var json = File.ReadAllText(jsonPath);
                checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(json);
            }

            checkpointFile.PreviewImagePath = FindPreviewImagePath(file);
            if (string.IsNullOrWhiteSpace(checkpointFile.PreviewImagePath))
            {
                checkpointFile.PreviewImagePath = Assets.NoImage.ToString();
            }

            checkpointFile.ParentFolder = parentFolder;
            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Recursively yields every supported checkpoint file under <paramref name="modelsDirectory"/>,
    /// loading its cm-info sidecar (and inferring <c>ModelType</c> from the path when one exists)
    /// and its preview image, if present. <c>ParentFolder</c> is not set on the returned items.
    /// </summary>
    public static IEnumerable<CheckpointFile> GetAllCheckpointFiles(string modelsDirectory)
    {
        foreach (var file in Directory.EnumerateFiles(modelsDirectory, "*.*", SearchOption.AllDirectories))
        {
            if (!IsSupportedCheckpointFile(file))
                continue;

            var checkpointFile = new CheckpointFile { Title = Path.GetFileNameWithoutExtension(file), FilePath = file };

            var jsonPath = GetConnectedModelInfoPath(file);
            if (File.Exists(jsonPath))
            {
                var json = File.ReadAllText(jsonPath);
                checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(json);
                checkpointFile.ModelType = GetCivitModelType(file);
            }

            checkpointFile.PreviewImagePath = FindPreviewImagePath(file);
            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Index with progress reporting.
    /// Reports a running count with the current file name for each yielded item.
    /// </summary>
    public static IEnumerable<CheckpointFile> FromDirectoryIndex(
        CheckpointFolder parentFolder,
        string directory,
        IProgress<ProgressReport> progress,
        SearchOption searchOption = SearchOption.TopDirectoryOnly
    )
    {
        var current = 0ul;
        foreach (var checkpointFile in FromDirectoryIndex(parentFolder, directory, searchOption))
        {
            current++;
            progress.Report(new ProgressReport(current, "Indexing", checkpointFile.FileName));
            yield return checkpointFile;
        }
    }

    /// <summary>
    /// Infers the CivitAI model type from the shared-folder name contained in the path.
    /// Checks are ordinal substring matches evaluated in order; the first match wins.
    /// </summary>
    private static CivitModelType GetCivitModelType(string filePath)
    {
        if (filePath.Contains(SharedFolderType.StableDiffusion.ToString()))
        {
            return CivitModelType.Checkpoint;
        }

        if (filePath.Contains(SharedFolderType.ControlNet.ToString()))
        {
            return CivitModelType.Controlnet;
        }

        if (filePath.Contains(SharedFolderType.Lora.ToString()))
        {
            return CivitModelType.LORA;
        }

        if (filePath.Contains(SharedFolderType.TextualInversion.ToString()))
        {
            return CivitModelType.TextualInversion;
        }

        if (filePath.Contains(SharedFolderType.Hypernetwork.ToString()))
        {
            return CivitModelType.Hypernetwork;
        }

        if (filePath.Contains(SharedFolderType.LyCORIS.ToString()))
        {
            return CivitModelType.LoCon;
        }

        return CivitModelType.Unknown;
    }

    /// <summary>
    /// Treats two items as equal when they have the same runtime type, file path, BLAKE3 hash,
    /// thumbnail URL and preview path. The hash code uses only the file path, which is consistent
    /// because equal items always share a file path.
    /// </summary>
    private sealed class FilePathEqualityComparer : IEqualityComparer<CheckpointFile>
    {
        public bool Equals(CheckpointFile? x, CheckpointFile? y)
        {
            if (ReferenceEquals(x, y))
                return true;
            if (ReferenceEquals(x, null))
                return false;
            if (ReferenceEquals(y, null))
                return false;
            if (x.GetType() != y.GetType())
                return false;

            return x.FilePath == y.FilePath
                && x.ConnectedModel?.Hashes.BLAKE3 == y.ConnectedModel?.Hashes.BLAKE3
                && x.ConnectedModel?.ThumbnailImageUrl == y.ConnectedModel?.ThumbnailImageUrl
                && x.PreviewImagePath == y.PreviewImagePath;
        }

        public int GetHashCode(CheckpointFile obj)
        {
            return obj.FilePath.GetHashCode();
        }
    }

    public static IEqualityComparer<CheckpointFile> FilePathComparer { get; } = new FilePathEqualityComparer();
}