You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
463 lines
15 KiB
463 lines
15 KiB
using System; |
|
using System.Collections.Generic; |
|
using System.Collections.ObjectModel; |
|
using System.IO; |
|
using System.Linq; |
|
using System.Threading.Tasks; |
|
using Avalonia.Data; |
|
using CommunityToolkit.Mvvm.ComponentModel; |
|
using CommunityToolkit.Mvvm.Input; |
|
using FluentAvalonia.UI.Controls; |
|
using NLog; |
|
using StabilityMatrix.Avalonia.Languages; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Extensions; |
|
using StabilityMatrix.Core.Helper; |
|
using StabilityMatrix.Core.Models; |
|
using StabilityMatrix.Core.Models.Api; |
|
using StabilityMatrix.Core.Models.Progress; |
|
using StabilityMatrix.Core.Processes; |
|
using StabilityMatrix.Core.Services; |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager; |
|
|
|
[ManagedService] |
|
[Transient] |
|
public partial class CheckpointFile : ViewModelBase |
|
{ |
|
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

/// <summary>
/// Absolute path to the checkpoint file.
/// </summary>
[ObservableProperty, NotifyPropertyChangedFor(nameof(FileName))]
private string filePath = string.Empty;

/// <summary>
/// Custom title for UI.
/// </summary>
[ObservableProperty]
private string title = string.Empty;

/// <summary>
/// Path to preview image. Can be local or remote URL.
/// </summary>
[ObservableProperty]
private string? previewImagePath;

/// <summary>
/// Connected model info for this file, or null when none is attached.
/// Changing it also refreshes <see cref="IsConnectedModel"/> and
/// <see cref="CanShowTriggerWords"/>, since both are computed from this value.
/// </summary>
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsConnectedModel))]
[NotifyPropertyChangedFor(nameof(CanShowTriggerWords))]
private ConnectedModelInfo? connectedModel;

/// <summary>
/// True when <see cref="ConnectedModel"/> is not null.
/// </summary>
public bool IsConnectedModel => ConnectedModel != null;

/// <summary>
/// Set to true while a delete or metadata lookup started by this view model is running.
/// </summary>
[ObservableProperty]
private bool isLoading;

/// <summary>
/// Civit model type of this file.
/// </summary>
[ObservableProperty]
private CivitModelType modelType;

/// <summary>
/// Folder this file was indexed for; assigned by FromDirectoryIndex.
/// </summary>
[ObservableProperty]
private CheckpointFolder parentFolder;

/// <summary>
/// Most recent progress report received during a metadata lookup.
/// </summary>
[ObservableProperty]
private ProgressReport? progress;
|
|
|
/// <summary>
/// File name (including extension) taken from <see cref="FilePath"/>.
/// </summary>
public string FileName
{
    get { return Path.GetFileName(FilePath); }
}

/// <summary>
/// True when a connected model is present and its trained-words string is not blank.
/// </summary>
public bool CanShowTriggerWords
{
    get { return !string.IsNullOrWhiteSpace(ConnectedModel?.TrainedWordsString); }
}

/// <summary>
/// Badge labels for this file; rebuilt whenever the connected model changes.
/// </summary>
public ObservableCollection<string> Badges { get; set; } = new ObservableCollection<string>();

/// <summary>
/// File extensions (with leading dot) that are treated as checkpoint files when indexing.
/// </summary>
public static readonly string[] SupportedCheckpointExtensions = new[]
{
    ".safetensors",
    ".pt",
    ".ckpt",
    ".pth",
    ".bin"
};

// Extensions probed, in this order, when looking for a "<name>.preview<ext>" image.
private static readonly string[] SupportedImageExtensions = new[]
{
    ".png",
    ".jpg",
    ".jpeg",
    ".gif"
};

private static readonly string[] SupportedMetadataExtensions = new[] { ".json" };
|
|
|
/// <summary>
/// Change hook for the <see cref="ConnectedModel"/> observable property.
/// Refreshes <see cref="Title"/> and rebuilds <see cref="Badges"/> from the new value.
/// </summary>
/// <param name="value">The new connected model info, or null.</param>
partial void OnConnectedModelChanged(ConnectedModelInfo? value)
{
    // Title preference: user-defined title first, then the connected model name, else empty.
    if (value is null)
    {
        Title = string.Empty;
    }
    else
    {
        Title = value.UserTitle ?? value.ModelName ?? string.Empty;
    }

    // Badges are rebuilt from scratch on every change.
    Badges.Clear();

    // Floating-point type badge, upper-cased; skipped when the model or its Fp value is missing.
    if (value?.FileMetadata.Fp?.GetStringValue().ToUpperInvariant() is { } precisionBadge)
    {
        Badges.Add(precisionBadge);
    }

    // Base model badge, skipped when blank.
    var baseModel = value?.BaseModel;
    if (!string.IsNullOrWhiteSpace(baseModel))
    {
        Badges.Add(baseModel);
    }
}
|
|
|
/// <summary>
/// Builds the path of the connected model info file that sits next to the checkpoint:
/// "&lt;checkpoint name without extension&gt;.cm-info.json" in the checkpoint's directory.
/// </summary>
/// <returns>Path of the .cm-info.json file for this checkpoint.</returns>
/// <exception cref="InvalidOperationException">Thrown when <see cref="FilePath"/> is null or empty.</exception>
private string GetConnectedModelInfoFilePath()
{
    if (string.IsNullOrEmpty(FilePath))
    {
        throw new InvalidOperationException(
            "Cannot get connected model info file path when FilePath is empty"
        );
    }

    var containingDirectory = Path.GetDirectoryName(FilePath) ?? "";
    var baseName = Path.GetFileNameWithoutExtension(FilePath);

    return Path.Combine(containingDirectory, baseName + ".cm-info.json");
}
|
|
|
/// <summary>
/// Deletes the checkpoint file from disk together with its preview image and
/// connected model info file (when present), then removes this entry from its parent list.
/// If a deletion fails, the failure is logged and the entry is left in the list.
/// </summary>
[RelayCommand]
private async Task DeleteAsync()
{
    if (File.Exists(FilePath))
    {
        IsLoading = true;
        try
        {
            await using var delay = new MinimumDelay(200, 500);

            await Task.Run(() => File.Delete(FilePath));

            if (PreviewImagePath != null && File.Exists(PreviewImagePath))
            {
                await Task.Run(() => File.Delete(PreviewImagePath));
            }

            if (ConnectedModel != null)
            {
                var cmInfoPath = GetConnectedModelInfoFilePath();
                if (File.Exists(cmInfoPath))
                {
                    await Task.Run(() => File.Delete(cmInfoPath));
                }
            }
        }
        // File.Delete reports read-only files and missing permissions with
        // UnauthorizedAccessException, which is not an IOException; handle both the same way
        // so the command does not fault.
        catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
        {
            Logger.Warn(ex, $"Failed to delete checkpoint file {FilePath}");
            return; // Don't delete from collection
        }
        finally
        {
            IsLoading = false;
        }
    }

    RemoveFromParentList();
}
|
|
|
/// <summary>
/// Removes this entry from its parent list; intended to be called after the file has been moved.
/// </summary>
public void OnMoved()
{
    RemoveFromParentList();
}
|
|
|
/// <summary>
/// Prompts for a new file name and, when confirmed, renames the checkpoint file on disk.
/// The preview image and the connected model info file ("&lt;name&gt;.cm-info.json") are
/// renamed alongside it when they exist. Failures are logged as warnings.
/// </summary>
[RelayCommand]
private async Task RenameAsync()
{
    // Directory that contains the checkpoint; all renamed files stay in it.
    var containingDirectory = Path.GetDirectoryName(FilePath) ?? "";

    var nameField = new TextBoxField
    {
        Label = "File name",
        Validator = text =>
        {
            if (string.IsNullOrWhiteSpace(text))
                throw new DataValidationException("File name is required");

            if (File.Exists(Path.Combine(containingDirectory, text)))
                throw new DataValidationException("File name already exists");
        },
        Text = FileName
    };
    var textFields = new[] { nameField };

    var dialog = DialogHelper.CreateTextEntryDialog("Rename Model", "", textFields);

    // Anything other than the primary button means the rename was cancelled.
    if (await dialog.ShowAsync() != ContentDialogResult.Primary)
    {
        return;
    }

    var newName = textFields[0].Text;
    var newStem = Path.GetFileNameWithoutExtension(newName);
    // Captured before FilePath changes, so the old sidecar file can still be located.
    var oldStem = Path.GetFileNameWithoutExtension(FileName);

    try
    {
        // Rename the checkpoint itself.
        var renamedPath = Path.Combine(containingDirectory, newName);
        File.Move(FilePath, renamedPath);
        FilePath = renamedPath;

        // Rename the preview image, keeping its original extension.
        if (PreviewImagePath is { } oldPreviewPath && File.Exists(oldPreviewPath))
        {
            var previewExtension = Path.GetExtension(oldPreviewPath);
            var renamedPreviewPath = Path.Combine(
                containingDirectory,
                $"{newStem}.preview{previewExtension}"
            );
            File.Move(oldPreviewPath, renamedPreviewPath);
            PreviewImagePath = renamedPreviewPath;
        }

        // Rename the connected model info file (<name>.cm-info.json).
        if (ConnectedModel is not null)
        {
            var oldInfoPath = Path.Combine(containingDirectory, $"{oldStem}.cm-info.json");
            if (File.Exists(oldInfoPath))
            {
                var renamedInfoPath = Path.Combine(containingDirectory, $"{newStem}.cm-info.json");
                File.Move(oldInfoPath, renamedInfoPath);
            }
        }
    }
    catch (Exception e)
    {
        Logger.Warn(e, $"Failed to rename checkpoint file {FilePath}");
    }
}
|
|
|
/// <summary>
/// Opens the connected model's CivitAI page via <c>ProcessRunner.OpenUrl</c>.
/// Does nothing when there is no connected model or it has no model id.
/// </summary>
[RelayCommand]
private void OpenOnCivitAi()
{
    if (ConnectedModel?.ModelId != null)
    {
        ProcessRunner.OpenUrl($"https://civitai.com/models/{ConnectedModel.ModelId}");
    }
}
|
|
|
/// <summary>
/// Copies the connected model's trained-words string to the clipboard.
/// Completes immediately when there is no connected model or the string is blank.
/// </summary>
/// <returns>The clipboard task, or a completed task when there is nothing to copy.</returns>
[RelayCommand]
private Task CopyTriggerWords()
{
    var triggerWords = ConnectedModel?.TrainedWordsString;

    return string.IsNullOrWhiteSpace(triggerWords)
        ? Task.CompletedTask
        : App.Clipboard.SetTextAsync(triggerWords);
}
|
|
|
/// <summary>
/// Copies the connected model's CivitAI URL to the clipboard.
/// Completes immediately when there is no connected model.
/// </summary>
/// <returns>The clipboard task, or a completed task when there is no connected model.</returns>
[RelayCommand]
private Task CopyModelUrl()
{
    if (ConnectedModel == null)
    {
        return Task.CompletedTask;
    }

    return App.Clipboard.SetTextAsync($"https://civitai.com/models/{ConnectedModel.ModelId}");
}
|
|
|
/// <summary>
/// Asks the <c>IMetadataImportService</c> for connected model info for this file and,
/// when found, applies it and re-resolves the preview image stored next to the file.
/// Returns without doing anything when the service is not registered.
/// </summary>
/// <param name="forceReimport">Forwarded unchanged to the import service.</param>
[RelayCommand]
private async Task FindConnectedMetadata(bool forceReimport = false)
{
    if (
        App.Services.GetService(typeof(IMetadataImportService))
        is not IMetadataImportService importService
    )
        return;

    IsLoading = true;

    try
    {
        var progressReport = new Progress<ProgressReport>(report =>
        {
            Progress = report;
        });

        var cmInfo = await importService.GetMetadataForFile(FilePath, progressReport, forceReimport);
        if (cmInfo != null)
        {
            ConnectedModel = cmInfo;

            // Look next to the checkpoint itself rather than in ParentFolder.DirectoryPath:
            // ParentFolder is not set for files created by GetAllCheckpointFiles, and files
            // indexed from subdirectories do not live in the parent folder's directory.
            var fileDirectory = Path.GetDirectoryName(FilePath) ?? "";
            var fileNameNoExt = Path.GetFileNameWithoutExtension(FilePath);

            PreviewImagePath = SupportedImageExtensions
                .Select(ext => Path.Combine(fileDirectory, $"{fileNameNoExt}.preview{ext}"))
                .Where(File.Exists)
                .FirstOrDefault();
        }
    }
    finally
    {
        IsLoading = false;
    }
}
|
|
|
/// <summary>
/// Indexes directory and yields all checkpoint files.
/// First we match all files with supported extensions.
/// If found, we also look, in the same directory as the matched file, for
/// - {filename}.preview.{image-extensions} (preview image)
/// - {filename}.cm-info.json (connected model info)
/// </summary>
/// <param name="parentFolder">Folder assigned to each yielded file's ParentFolder.</param>
/// <param name="directory">Directory to enumerate.</param>
/// <param name="searchOption">Whether subdirectories are enumerated as well.</param>
/// <returns>A lazily evaluated sequence with one entry per checkpoint file found.</returns>
public static IEnumerable<CheckpointFile> FromDirectoryIndex(
    CheckpointFolder parentFolder,
    string directory,
    SearchOption searchOption = SearchOption.TopDirectoryOnly
)
{
    foreach (var file in Directory.EnumerateFiles(directory, "*.*", searchOption))
    {
        var extension = Path.GetExtension(file);
        if (
            !SupportedCheckpointExtensions.Any(
                ext => extension.Equals(ext, StringComparison.InvariantCultureIgnoreCase)
            )
        )
            continue;

        // With SearchOption.AllDirectories the file may be in a subdirectory, so sidecar
        // files must be resolved against the file's own directory, not the search root.
        var fileDirectory = Path.GetDirectoryName(file) ?? directory;
        var fileNameNoExt = Path.GetFileNameWithoutExtension(file);

        var checkpointFile = new CheckpointFile
        {
            Title = fileNameNoExt,
            // EnumerateFiles already returns paths prefixed with the search directory;
            // combining again would double the prefix for relative directories.
            FilePath = file,
        };

        var jsonPath = Path.Combine(fileDirectory, $"{fileNameNoExt}.cm-info.json");
        if (File.Exists(jsonPath))
        {
            var json = File.ReadAllText(jsonPath);
            var connectedModelInfo = ConnectedModelInfo.FromJson(json);
            checkpointFile.ConnectedModel = connectedModelInfo;
        }

        checkpointFile.PreviewImagePath = SupportedImageExtensions
            .Select(ext => Path.Combine(fileDirectory, $"{fileNameNoExt}.preview{ext}"))
            .Where(File.Exists)
            .FirstOrDefault();

        // Fall back to the placeholder asset when no preview image exists.
        if (string.IsNullOrWhiteSpace(checkpointFile.PreviewImagePath))
        {
            checkpointFile.PreviewImagePath = Assets.NoImage.ToString();
        }

        checkpointFile.ParentFolder = parentFolder;

        yield return checkpointFile;
    }
}
|
|
|
/// <summary>
/// Recursively enumerates <paramref name="modelsDirectory"/> and yields an entry for every
/// file with a supported checkpoint extension. When a "{filename}.cm-info.json" file sits
/// next to the checkpoint, the connected model info is loaded and the model type is derived
/// from the file path. The first existing "{filename}.preview{ext}" image, if any, becomes
/// the preview image path.
/// </summary>
/// <param name="modelsDirectory">Root directory to search, including all subdirectories.</param>
/// <returns>A lazily evaluated sequence with one entry per checkpoint file found.</returns>
public static IEnumerable<CheckpointFile> GetAllCheckpointFiles(string modelsDirectory)
{
    var checkpointPaths = Directory
        .EnumerateFiles(modelsDirectory, "*.*", SearchOption.AllDirectories)
        .Where(
            path =>
                SupportedCheckpointExtensions.Any(
                    ext => Path.GetExtension(path).Equals(ext, StringComparison.InvariantCultureIgnoreCase)
                )
        );

    foreach (var path in checkpointPaths)
    {
        var containingDirectory = Path.GetDirectoryName(path) ?? "";
        var stem = Path.GetFileNameWithoutExtension(path);

        var entry = new CheckpointFile { Title = stem, FilePath = path };

        // Connected model info lives in a sidecar file next to the checkpoint.
        var infoPath = Path.Combine(containingDirectory, stem + ".cm-info.json");
        if (File.Exists(infoPath))
        {
            entry.ConnectedModel = ConnectedModelInfo.FromJson(File.ReadAllText(infoPath));
            entry.ModelType = GetCivitModelType(path);
        }

        // First preview image that exists, probing the supported extensions in order.
        string? previewPath = null;
        foreach (var imageExtension in SupportedImageExtensions)
        {
            var candidate = Path.Combine(containingDirectory, $"{stem}.preview{imageExtension}");
            if (File.Exists(candidate))
            {
                previewPath = candidate;
                break;
            }
        }
        entry.PreviewImagePath = previewPath;

        yield return entry;
    }
}
|
|
|
/// <summary>
/// Same as the overload without progress, but reports each file as it is yielded.
/// Every report carries the running count of files indexed so far, the text "Indexing",
/// and the file name of the entry just produced.
/// </summary>
/// <param name="parentFolder">Folder assigned to each yielded file's ParentFolder.</param>
/// <param name="directory">Directory to enumerate.</param>
/// <param name="progress">Receives one report per yielded file.</param>
/// <param name="searchOption">Whether subdirectories are enumerated as well.</param>
/// <returns>A lazily evaluated sequence with one entry per checkpoint file found.</returns>
public static IEnumerable<CheckpointFile> FromDirectoryIndex(
    CheckpointFolder parentFolder,
    string directory,
    IProgress<ProgressReport> progress,
    SearchOption searchOption = SearchOption.TopDirectoryOnly
)
{
    ulong indexedCount = 0;

    foreach (var entry in FromDirectoryIndex(parentFolder, directory, searchOption))
    {
        indexedCount += 1;

        var report = new ProgressReport(indexedCount, "Indexing", entry.FileName);
        progress.Report(report);

        yield return entry;
    }
}
|
|
|
/// <summary>
/// Derives the Civit model type from a file path by checking whether the path contains
/// the name of a known shared folder type. Candidates are checked in a fixed order and
/// the first match wins.
/// </summary>
/// <param name="filePath">Path of the checkpoint file.</param>
/// <returns>The matching model type, or <c>CivitModelType.Unknown</c> when nothing matches.</returns>
private static CivitModelType GetCivitModelType(string filePath)
{
    // Order matters: earlier entries take precedence when several folder names appear.
    var folderToModelType = new (SharedFolderType Folder, CivitModelType ModelType)[]
    {
        (SharedFolderType.StableDiffusion, CivitModelType.Checkpoint),
        (SharedFolderType.ControlNet, CivitModelType.Controlnet),
        (SharedFolderType.Lora, CivitModelType.LORA),
        (SharedFolderType.TextualInversion, CivitModelType.TextualInversion),
        (SharedFolderType.Hypernetwork, CivitModelType.Hypernetwork),
        (SharedFolderType.LyCORIS, CivitModelType.LoCon),
    };

    foreach (var (folder, modelType) in folderToModelType)
    {
        if (filePath.Contains(folder.ToString()))
        {
            return modelType;
        }
    }

    return CivitModelType.Unknown;
}
|
|
|
/// <summary>
/// Compares checkpoint files by file path, the connected model's BLAKE3 hash and
/// thumbnail image URL, and the preview image path. Hash codes use the file path only.
/// </summary>
private sealed class FilePathEqualityComparer : IEqualityComparer<CheckpointFile>
{
    /// <summary>
    /// Returns true when both arguments are the same instance, or are non-null instances
    /// of the same runtime type whose compared values all match.
    /// </summary>
    public bool Equals(CheckpointFile? x, CheckpointFile? y)
    {
        if (ReferenceEquals(x, y))
        {
            return true;
        }

        if (x is null || y is null)
        {
            return false;
        }

        if (x.GetType() != y.GetType())
        {
            return false;
        }

        // Compare field by field, stopping at the first difference.
        if (x.FilePath != y.FilePath)
        {
            return false;
        }

        if (x.ConnectedModel?.Hashes.BLAKE3 != y.ConnectedModel?.Hashes.BLAKE3)
        {
            return false;
        }

        if (x.ConnectedModel?.ThumbnailImageUrl != y.ConnectedModel?.ThumbnailImageUrl)
        {
            return false;
        }

        return x.PreviewImagePath == y.PreviewImagePath;
    }

    /// <summary>
    /// Hash code of the file path; equal files always share a path, so this agrees with Equals.
    /// </summary>
    public int GetHashCode(CheckpointFile obj) => obj.FilePath.GetHashCode();
}
|
|
|
/// <summary>
/// Shared comparer instance backed by <c>FilePathEqualityComparer</c>.
/// </summary>
public static IEqualityComparer<CheckpointFile> FilePathComparer { get; } = new FilePathEqualityComparer();
|
}
|
|
|