diff --git a/StabilityMatrix/Api/ICivitApi.cs b/StabilityMatrix/Api/ICivitApi.cs new file mode 100644 index 00000000..dd618f10 --- /dev/null +++ b/StabilityMatrix/Api/ICivitApi.cs @@ -0,0 +1,14 @@ +using System.Threading.Tasks; +using Refit; +using StabilityMatrix.Models.Api; + +namespace StabilityMatrix.Api; + +public interface ICivitApi +{ + [Get("/api/v1/models")] + Task GetModels(CivitModelsRequest request); + + [Get("/api/v1/model-versions/by-hash")] + Task GetModelVersionByHash([Query] string hash); +} diff --git a/StabilityMatrix/App.xaml.cs b/StabilityMatrix/App.xaml.cs index feb43770..477d923d 100644 --- a/StabilityMatrix/App.xaml.cs +++ b/StabilityMatrix/App.xaml.cs @@ -74,8 +74,9 @@ namespace StabilityMatrix serviceCollection.AddTransient(); serviceCollection.AddTransient(); serviceCollection.AddTransient(); + serviceCollection.AddTransient(); serviceCollection.AddTransient(); - + serviceCollection.AddTransient(); serviceCollection.AddTransient(); serviceCollection.AddTransient(); @@ -86,6 +87,7 @@ namespace StabilityMatrix serviceCollection.AddTransient(); serviceCollection.AddTransient(); serviceCollection.AddTransient(); + serviceCollection.AddSingleton(); var settingsManager = new SettingsManager(); serviceCollection.AddSingleton(settingsManager); @@ -136,6 +138,13 @@ namespace StabilityMatrix c.Timeout = TimeSpan.FromSeconds(2); }) .AddPolicyHandler(retryPolicy); + serviceCollection.AddRefitClient(defaultRefitSettings) + .ConfigureHttpClient(c => + { + c.BaseAddress = new Uri("https://civitai.com"); + c.Timeout = TimeSpan.FromSeconds(8); + }) + .AddPolicyHandler(retryPolicy); // Logging configuration var logPath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "log.txt"); diff --git a/StabilityMatrix/CheckpointBrowserPage.xaml b/StabilityMatrix/CheckpointBrowserPage.xaml new file mode 100644 index 00000000..19ba9b4a --- /dev/null +++ b/StabilityMatrix/CheckpointBrowserPage.xaml @@ -0,0 +1,209 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix/CheckpointBrowserPage.xaml.cs b/StabilityMatrix/CheckpointBrowserPage.xaml.cs new file mode 100644 index 00000000..6a3a2971 --- /dev/null +++ b/StabilityMatrix/CheckpointBrowserPage.xaml.cs @@ -0,0 +1,36 @@ +using System.Diagnostics; +using System.Threading; +using System.Windows; +using System.Windows.Controls; +using System.Windows.Input; +using System.Windows.Media; +using System.Windows.Media.Effects; +using StabilityMatrix.ViewModels; +using Wpf.Ui.Controls; + +namespace StabilityMatrix; + +public partial class CheckpointBrowserPage : Page +{ + public CheckpointBrowserPage(CheckpointBrowserViewModel viewModel) + { + InitializeComponent(); + DataContext = viewModel; + } + + private void VirtualizingGridView_OnPreviewMouseWheel(object sender, MouseWheelEventArgs e) + { + if (e.Handled) return; + + e.Handled = true; + var eventArg = new MouseWheelEventArgs(e.MouseDevice, e.Timestamp, e.Delta) + { + RoutedEvent = MouseWheelEvent, + Source = sender + }; + if (((Control)sender).Parent is UIElement parent) + { + parent.RaiseEvent(eventArg); + } + } +} diff --git a/StabilityMatrix/CheckpointManagerPage.xaml b/StabilityMatrix/CheckpointManagerPage.xaml index 62209d12..4b331c1e 100644 --- a/StabilityMatrix/CheckpointManagerPage.xaml +++ b/StabilityMatrix/CheckpointManagerPage.xaml @@ -9,8 +9,10 @@ mc:Ignorable="d" x:Class="StabilityMatrix.CheckpointManagerPage" xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation" + xmlns:converters="clr-namespace:StabilityMatrix.Converters" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:designData="clr-namespace:StabilityMatrix.DesignData" + xmlns:i="http://schemas.microsoft.com/xaml/behaviors" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:models="clr-namespace:StabilityMatrix.Models" xmlns:ui="http://schemas.lepo.co/wpfui/2022/xaml" @@ -18,31 +20,122 @@ xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"> + + + + + + + - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -51,12 +144,88 @@ Header="{Binding Title}" IsExpanded="True" Margin="8"> - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -82,7 +251,7 @@ HorizontalAlignment="Stretch" ItemTemplate="{StaticResource CheckpointFolderGridDataTemplate}" ItemsSource="{Binding CheckpointFolders, Mode=OneWay}" - Margin="16,16,16,16" /> + Margin="8" /> diff --git a/StabilityMatrix/Converters/BooleanToHiddenVisibleConverter.cs b/StabilityMatrix/Converters/BooleanToHiddenVisibleConverter.cs new file mode 100644 index 00000000..13e69d46 --- /dev/null +++ b/StabilityMatrix/Converters/BooleanToHiddenVisibleConverter.cs @@ -0,0 +1,32 @@ +using System; +using System.Globalization; +using System.Windows; +using System.Windows.Data; + +namespace StabilityMatrix.Converters; + +public class BooleanToHiddenVisibleConverter : IValueConverter +{ + public object Convert(object value, Type targetType, object parameter, CultureInfo culture) + { + var bValue = false; + if (value is bool b) + { + bValue = b; + } + else if (value is bool) + { + var tmp = (bool?) value; + bValue = tmp.Value; + } + return bValue ? Visibility.Visible : Visibility.Hidden; + } + public object ConvertBack(object value, Type targetType, object parameter, CultureInfo culture) + { + if (value is Visibility visibility) + { + return visibility == Visibility.Visible; + } + return false; + } +} diff --git a/StabilityMatrix/Converters/UriToBitmapConverter.cs b/StabilityMatrix/Converters/UriToBitmapConverter.cs index e0f8958a..b081248e 100644 --- a/StabilityMatrix/Converters/UriToBitmapConverter.cs +++ b/StabilityMatrix/Converters/UriToBitmapConverter.cs @@ -14,6 +14,11 @@ public class UriToBitmapConverter : IValueConverter return new BitmapImage(uri); } + if (value is string uriString) + { + return new BitmapImage(new Uri(uriString)); + } + return null; } diff --git a/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs b/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs new file mode 100644 index 00000000..94940ebf --- /dev/null +++ b/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs @@ -0,0 +1,39 @@ +using System.Collections.ObjectModel; +using System.ComponentModel; +using StabilityMatrix.Models.Api; +using StabilityMatrix.ViewModels; + +namespace StabilityMatrix.DesignData; + +[DesignOnly(true)] +public class MockCheckpointBrowserViewModel : CheckpointBrowserViewModel +{ + public MockCheckpointBrowserViewModel() : base(null!, null!, null!) + { + ModelCards = new ObservableCollection + { + new (null!, null!, null!) + { + CivitModel = new() + { + Name = "bb95 Furry Mix", + ModelVersions = new[] + { + new CivitModelVersion + { + Name = "v7.0", + Images = new[] + { + new CivitImage + { + Url = + "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/1547f350-461a-4cd0-a753-0544aa81e4fc/width=450/00000-4137473915.jpeg" + } + } + } + } + } + } + }; + } +} diff --git a/StabilityMatrix/DesignData/MockCheckpointManagerViewModel.cs b/StabilityMatrix/DesignData/MockCheckpointManagerViewModel.cs index 82421219..877880e5 100644 --- a/StabilityMatrix/DesignData/MockCheckpointManagerViewModel.cs +++ b/StabilityMatrix/DesignData/MockCheckpointManagerViewModel.cs @@ -40,6 +40,7 @@ public class MockCheckpointManagerViewModel : CheckpointManagerViewModel new() { Title = "Lora", + IsCurrentDragTarget = true, CheckpointFiles = new() { new() diff --git a/StabilityMatrix/Extensions/EnumAttributes.cs b/StabilityMatrix/Extensions/EnumAttributes.cs new file mode 100644 index 00000000..e1d388ec --- /dev/null +++ b/StabilityMatrix/Extensions/EnumAttributes.cs @@ -0,0 +1,56 @@ +using System; +using System.Linq; +using System.Windows.Ink; + +namespace StabilityMatrix.Extensions; + +public static class EnumAttributeExtensions +{ + private static T? GetAttributeValue(Enum value) + { + var type = value.GetType(); + var fieldInfo = type.GetField(value.ToString()); + // Get the string value attributes + var attribs = fieldInfo?.GetCustomAttributes(typeof(T), false) as T[]; + // Return the first if there was a match. + return attribs?.Length > 0 ? attribs[0] : default; + } + /// + /// Gets the StringValue field attribute on a given enum value. + /// If not found, returns the enum value itself as a string. + /// + /// + /// + public static string GetStringValue(this Enum value) + { + var attr = GetAttributeValue(value)?.StringValue; + return attr ?? Enum.GetName(value.GetType(), value)!; + } + /// + /// Gets the Description field attribute on a given enum value. + /// + /// + /// + public static string? GetDescription(this Enum value) + { + return GetAttributeValue(value)?.Description; + } +} + +[AttributeUsage(AttributeTargets.Field)] +public sealed class StringValueAttribute : Attribute +{ + public string StringValue { get; } + public StringValueAttribute(string value) { + StringValue = value; + } +} + +[AttributeUsage(AttributeTargets.Field)] +public sealed class DescriptionAttribute : Attribute +{ + public string Description { get; } + public DescriptionAttribute(string value) { + Description = value; + } +} diff --git a/StabilityMatrix/Extensions/EnumConversion.cs b/StabilityMatrix/Extensions/EnumConversion.cs new file mode 100644 index 00000000..10500508 --- /dev/null +++ b/StabilityMatrix/Extensions/EnumConversion.cs @@ -0,0 +1,26 @@ +using System; + +namespace StabilityMatrix.Extensions; + +public static class EnumConversionExtensions +{ + public static T? ConvertTo(this Enum value) where T : Enum + { + var type = value.GetType(); + var fieldInfo = type.GetField(value.ToString()); + // Get the string value attributes + var attribs = fieldInfo?.GetCustomAttributes(typeof(ConvertToAttribute), false) as ConvertToAttribute[]; + // Return the first if there was a match. + return attribs?.Length > 0 ? attribs[0].ConvertToEnum : default; + } +} + +[AttributeUsage(AttributeTargets.Field)] +public sealed class ConvertToAttribute : Attribute where T : Enum +{ + public T ConvertToEnum { get; } + public ConvertToAttribute(T toEnum) + { + ConvertToEnum = toEnum; + } +} diff --git a/StabilityMatrix/Helper/FileHash.cs b/StabilityMatrix/Helper/FileHash.cs new file mode 100644 index 00000000..bfc10c09 --- /dev/null +++ b/StabilityMatrix/Helper/FileHash.cs @@ -0,0 +1,61 @@ +using System; +using System.Buffers; +using System.IO; +using System.Security.Cryptography; +using System.Threading.Tasks; +using StabilityMatrix.Models; + +namespace StabilityMatrix.Helper; + +public static class FileHash +{ + public static async Task GetHashAsync(HashAlgorithm hashAlgorithm, Stream stream, byte[] buffer, Action? progress = default) + { + ulong totalBytesRead = 0; + + using (hashAlgorithm) + { + int bytesRead; + while ((bytesRead = await stream.ReadAsync(buffer)) != 0) + { + totalBytesRead += (ulong) bytesRead; + hashAlgorithm.TransformBlock(buffer, 0, bytesRead, null, 0); + progress?.Invoke(totalBytesRead); + } + hashAlgorithm.TransformFinalBlock(buffer, 0, 0); + var hash = hashAlgorithm.Hash; + if (hash == null || hash.Length == 0) + { + throw new InvalidOperationException("Hash algorithm did not produce a hash."); + } + return BitConverter.ToString(hash).Replace("-", string.Empty).ToLowerInvariant(); + } + } + + public static async Task GetSha256Async(string filePath, IProgress? progress = default) + { + if (!File.Exists(filePath)) + { + throw new FileNotFoundException($"Could not find file: {filePath}"); + } + + var totalBytes = Convert.ToUInt64(new FileInfo(filePath).Length); + var shared = ArrayPool.Shared; + var buffer = shared.Rent((int) FileTransfers.GetBufferSize(totalBytes)); + try + { + await using var stream = File.OpenRead(filePath); + + var hash = await GetHashAsync(SHA256.Create(), stream, buffer, totalBytesRead => + { + progress?.Report(new ProgressReport(totalBytesRead, totalBytes)); + }); + return hash; + } + finally + { + shared.Return(buffer); + } + + } +} diff --git a/StabilityMatrix/Helper/FileTransfers.cs b/StabilityMatrix/Helper/FileTransfers.cs new file mode 100644 index 00000000..dc1384da --- /dev/null +++ b/StabilityMatrix/Helper/FileTransfers.cs @@ -0,0 +1,76 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using StabilityMatrix.Models; + +namespace StabilityMatrix.Helper; + +public static class FileTransfers +{ + /// + /// Determines suitable buffer size based on stream length. + /// + /// + /// + public static ulong GetBufferSize(ulong totalBytes) => totalBytes switch + { + < Size.MiB => 8 * Size.KiB, + < 100 * Size.MiB => 16 * Size.KiB, + < 500 * Size.MiB => Size.MiB, + < Size.GiB => 16 * Size.MiB, + _ => 32 * Size.MiB + }; + + public static async Task CopyFiles(Dictionary files, IProgress? fileProgress = default, IProgress? totalProgress = default) + { + var totalFiles = files.Count; + var currentFiles = 0; + var totalSize = Convert.ToUInt64(files.Keys.Select(x => new FileInfo(x).Length).Sum()); + var totalRead = 0ul; + + foreach(var (sourcePath, destPath) in files) + { + var totalReadForFile = 0ul; + + await using var outStream = new FileStream(destPath, FileMode.Create, FileAccess.Write, FileShare.Read); + await using var inStream = new FileStream(sourcePath, FileMode.Open, FileAccess.Read, FileShare.Read); + var fileSize = (ulong) inStream.Length; + var fileName = Path.GetFileName(sourcePath); + currentFiles++; + await CopyStream(inStream , outStream, fileReadBytes => + { + var lastRead = totalReadForFile; + totalReadForFile = Convert.ToUInt64(fileReadBytes); + totalRead += totalReadForFile - lastRead; + fileProgress?.Report(new ProgressReport(totalReadForFile, fileSize, fileName, $"{currentFiles}/{totalFiles}")); + totalProgress?.Report(new ProgressReport(totalRead, totalSize, fileName, $"{currentFiles}/{totalFiles}")); + } ); + } + } + + private static async Task CopyStream(Stream from, Stream to, Action progress) + { + var shared = ArrayPool.Shared; + var bufferSize = (int) GetBufferSize((ulong) from.Length); + var buffer = shared.Rent(bufferSize); + var totalRead = 0L; + + try + { + while (totalRead < from.Length) + { + var read = await from.ReadAsync(buffer.AsMemory(0, bufferSize)); + await to.WriteAsync(buffer.AsMemory(0, read)); + totalRead += read; + progress(totalRead); + } + } + finally + { + shared.Return(buffer); + } + } +} diff --git a/StabilityMatrix/Helper/PrerequisiteHelper.cs b/StabilityMatrix/Helper/PrerequisiteHelper.cs index b104091c..4f903eb2 100644 --- a/StabilityMatrix/Helper/PrerequisiteHelper.cs +++ b/StabilityMatrix/Helper/PrerequisiteHelper.cs @@ -59,13 +59,13 @@ public class PrerequisiteHelper : IPrerequisiteHelper if (!File.Exists(PortableGitDownloadPath)) { - downloadService.DownloadProgressChanged += OnDownloadProgressChanged; - downloadService.DownloadComplete += OnDownloadComplete; + var progress = new Progress(progress => + { + OnDownloadProgressChanged(this, progress); + }); - await downloadService.DownloadToFileAsync(portableGitUrl, PortableGitDownloadPath); - - downloadService.DownloadProgressChanged -= OnDownloadProgressChanged; - downloadService.DownloadComplete -= OnDownloadComplete; + await downloadService.DownloadToFileAsync(portableGitUrl, PortableGitDownloadPath, progress: progress); + OnDownloadComplete(this, new ProgressReport(progress: 1f)); } await UnzipGit(); diff --git a/StabilityMatrix/MainWindow.xaml b/StabilityMatrix/MainWindow.xaml index 9d3afb6f..822feac4 100644 --- a/StabilityMatrix/MainWindow.xaml +++ b/StabilityMatrix/MainWindow.xaml @@ -1,19 +1,20 @@ + Margin="24,16,0,16" /> + + + + + @@ -89,7 +95,7 @@ - + diff --git a/StabilityMatrix/Models/Api/CivitCommercialUse.cs b/StabilityMatrix/Models/Api/CivitCommercialUse.cs new file mode 100644 index 00000000..1b5a97fb --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitCommercialUse.cs @@ -0,0 +1,13 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum CivitCommercialUse +{ + None, + Image, + Rent, + Sell +} diff --git a/StabilityMatrix/Models/Api/CivitCreator.cs b/StabilityMatrix/Models/Api/CivitCreator.cs new file mode 100644 index 00000000..bd37ff5d --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitCreator.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitCreator +{ + [JsonPropertyName("username")] + public string Username { get; set; } + + [JsonPropertyName("image")] + public string? Image { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitFile.cs b/StabilityMatrix/Models/Api/CivitFile.cs new file mode 100644 index 00000000..626f2e30 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitFile.cs @@ -0,0 +1,31 @@ +using System; +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitFile +{ + [JsonPropertyName("sizeKb")] + public double SizeKb { get; set; } + + [JsonPropertyName("pickleScanResult")] + public string PickleScanResult { get; set; } + + [JsonPropertyName("virusScanResult")] + public string VirusScanResult { get; set; } + + [JsonPropertyName("scannedAt")] + public DateTime? ScannedAt { get; set; } + + [JsonPropertyName("metadata")] + public CivitFileMetadata Metadata { get; set; } + + [JsonPropertyName("name")] + public string Name { get; set; } + + [JsonPropertyName("downloadUrl")] + public string DownloadUrl { get; set; } + + [JsonPropertyName("hashes")] + public CivitFileHashes Hashes { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitFileHashes.cs b/StabilityMatrix/Models/Api/CivitFileHashes.cs new file mode 100644 index 00000000..1d0d37e1 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitFileHashes.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitFileHashes +{ + public string? SHA256 { get; set; } + + public string? CRC32 { get; set; } + + public string? BLAKE3 { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitFileMetadata.cs b/StabilityMatrix/Models/Api/CivitFileMetadata.cs new file mode 100644 index 00000000..17fc3174 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitFileMetadata.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitFileMetadata +{ + [JsonPropertyName("fp")] + public CivitModelFpType? Fp { get; set; } + + [JsonPropertyName("size")] + public CivitModelSize? Size { get; set; } + + [JsonPropertyName("format")] + public CivitModelFormat? Format { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitImage.cs b/StabilityMatrix/Models/Api/CivitImage.cs new file mode 100644 index 00000000..efe86599 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitImage.cs @@ -0,0 +1,23 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitImage +{ + [JsonPropertyName("url")] + public string Url { get; set; } + + [JsonPropertyName("nsfw")] + public string Nsfw { get; set; } + + [JsonPropertyName("width")] + public int Width { get; set; } + + [JsonPropertyName("height")] + public int Height { get; set; } + + [JsonPropertyName("hash")] + public string Hash { get; set; } + + // TODO: "meta" ( object? ) +} diff --git a/StabilityMatrix/Models/Api/CivitMetadata.cs b/StabilityMatrix/Models/Api/CivitMetadata.cs new file mode 100644 index 00000000..b4bfdbd0 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitMetadata.cs @@ -0,0 +1,25 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + + +public class CivitMetadata +{ + [JsonPropertyName("totalItems")] + public int TotalItems { get; set; } + + [JsonPropertyName("currentPage")] + public int CurrentPage { get; set; } + + [JsonPropertyName("pageSize")] + public int PageSize { get; set; } + + [JsonPropertyName("totalPages")] + public int TotalPages { get; set; } + + [JsonPropertyName("nextPage")] + public string? NextPage { get; set; } + + [JsonPropertyName("prevPage")] + public string? PrevPage { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitMode.cs b/StabilityMatrix/Models/Api/CivitMode.cs new file mode 100644 index 00000000..b79273c0 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitMode.cs @@ -0,0 +1,10 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum CivitMode +{ + Archived, + TakenDown +} diff --git a/StabilityMatrix/Models/Api/CivitModel.cs b/StabilityMatrix/Models/Api/CivitModel.cs new file mode 100644 index 00000000..1058ed98 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModel.cs @@ -0,0 +1,36 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitModel +{ + [JsonPropertyName("id")] + public int Id { get; set; } + + [JsonPropertyName("name")] + public string Name { get; set; } + + [JsonPropertyName("description")] + public string Description { get; set; } + + [JsonPropertyName("type")] + public CivitModelType Type { get; set; } + + [JsonPropertyName("nsfw")] + public bool Nsfw { get; set; } + + [JsonPropertyName("tags")] + public string[] Tags { get; set; } + + [JsonPropertyName("mode")] + public CivitMode? Mode { get; set; } + + [JsonPropertyName("creator")] + public CivitCreator Creator { get; set; } + + [JsonPropertyName("stats")] + public CivitModelStats Stats { get; set; } + + [JsonPropertyName("modelVersions")] + public CivitModelVersion[] ModelVersions { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitModelFormat.cs b/StabilityMatrix/Models/Api/CivitModelFormat.cs new file mode 100644 index 00000000..2ab4435f --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelFormat.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + + +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum CivitModelFormat +{ + SafeTensor, + PickleTensor, + Other +} diff --git a/StabilityMatrix/Models/Api/CivitModelFpType.cs b/StabilityMatrix/Models/Api/CivitModelFpType.cs new file mode 100644 index 00000000..acef4452 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelFpType.cs @@ -0,0 +1,13 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + + +[JsonConverter(typeof(JsonStringEnumConverter))] +[SuppressMessage("ReSharper", "InconsistentNaming")] +public enum CivitModelFpType +{ + fp16, + fp32 +} diff --git a/StabilityMatrix/Models/Api/CivitModelSize.cs b/StabilityMatrix/Models/Api/CivitModelSize.cs new file mode 100644 index 00000000..5d78c753 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelSize.cs @@ -0,0 +1,12 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +[JsonConverter(typeof(JsonStringEnumConverter))] +[SuppressMessage("ReSharper", "InconsistentNaming")] +public enum CivitModelSize +{ + full, + pruned, +} diff --git a/StabilityMatrix/Models/Api/CivitModelStats.cs b/StabilityMatrix/Models/Api/CivitModelStats.cs new file mode 100644 index 00000000..8f52a991 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelStats.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitModelStats : CivitStats +{ + [JsonPropertyName("favoriteCount")] + public int FavoriteCount { get; set; } + + [JsonPropertyName("commentCount")] + public int CommentCount { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitModelType.cs b/StabilityMatrix/Models/Api/CivitModelType.cs new file mode 100644 index 00000000..5198d15e --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelType.cs @@ -0,0 +1,28 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using StabilityMatrix.Extensions; + +namespace StabilityMatrix.Models.Api; + +[JsonConverter(typeof(JsonStringEnumConverter))] +[SuppressMessage("ReSharper", "InconsistentNaming")] +public enum CivitModelType +{ + [ConvertTo(SharedFolderType.StableDiffusion)] + Checkpoint, + [ConvertTo(SharedFolderType.TextualInversion)] + TextualInversion, + [ConvertTo(SharedFolderType.Hypernetwork)] + Hypernetwork, + AestheticGradient, + [ConvertTo(SharedFolderType.Lora)] + LORA, + [ConvertTo(SharedFolderType.ControlNet)] + Controlnet, + Poses, + [ConvertTo(SharedFolderType.StableDiffusion)] + Model, + [ConvertTo(SharedFolderType.LyCORIS)] + LoCon +} diff --git a/StabilityMatrix/Models/Api/CivitModelVersion.cs b/StabilityMatrix/Models/Api/CivitModelVersion.cs new file mode 100644 index 00000000..d9cbb15e --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelVersion.cs @@ -0,0 +1,37 @@ +using System; +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitModelVersion +{ + [JsonPropertyName("id")] + public int Id { get; set; } + + [JsonPropertyName("name")] + public string Name { get; set; } + + [JsonPropertyName("description")] + public string Description { get; set; } + + [JsonPropertyName("createdAt")] + public DateTime CreatedAt { get; set; } + + [JsonPropertyName("downloadUrl")] + public string DownloadUrl { get; set; } + + [JsonPropertyName("trainedWords")] + public string[] TrainedWords { get; set; } + + [JsonPropertyName("baseModel")] + public string? BaseModel { get; set; } + + [JsonPropertyName("files")] + public CivitFile[] Files { get; set; } + + [JsonPropertyName("images")] + public CivitImage[] Images { get; set; } + + [JsonPropertyName("stats")] + public CivitModelStats Stats { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitModelsRequest.cs b/StabilityMatrix/Models/Api/CivitModelsRequest.cs new file mode 100644 index 00000000..38b0c91d --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelsRequest.cs @@ -0,0 +1,106 @@ +using System.Text.Json.Serialization; +using Refit; + +namespace StabilityMatrix.Models.Api; + + +public class CivitModelsRequest +{ + /// + /// The number of results to be returned per page. This can be a number between 1 and 200. By default, each page will return 100 results + /// + [AliasAs("limit")] + public int? Limit { get; set; } + + /// + /// The page from which to start fetching models + /// + [AliasAs("page")] + public int? Page { get; set; } + + /// + /// Search query to filter models by name + /// + [AliasAs("query")] + public string? Query { get; set; } + + /// + /// Search query to filter models by tag + /// + [AliasAs("tag")] + public string? Tag { get; set; } + + /// + /// Search query to filter models by user + /// + [AliasAs("username")] + public string? Username { get; set; } + + /// + /// The type of model you want to filter with. If none is specified, it will return all types + /// + [AliasAs("types")] + public CivitModelType[]? Types { get; set; } + + /// + /// The order in which you wish to sort the results + /// + [AliasAs("sort")] + public CivitSortMode? Sort { get; set; } + + /// + /// The time frame in which the models will be sorted + /// + [AliasAs("period")] + public CivitPeriod? Period { get; set; } + + /// + /// The rating you wish to filter the models with. If none is specified, it will return models with any rating + /// + [AliasAs("rating")] + public int? Rating { get; set; } + + /// + /// Filter to models that require or don't require crediting the creator + /// Requires Authentication + /// + [AliasAs("favorites")] + public bool? Favorites { get; set; } + + /// + /// Filter to hidden models of the authenticated user + /// Requires Authentication + /// + [AliasAs("hidden")] + public bool? Hidden { get; set; } + + /// + /// Only include the primary file for each model (This will use your preferred format options if you use an API token or session cookie) + /// + [AliasAs("primaryFileOnly")] + public bool? PrimaryFileOnly { get; set; } + + /// + /// Filter to models that allow or don't allow creating derivatives + /// + [AliasAs("allowDerivatives")] + public bool? AllowDerivatives { get; set; } + + /// + /// Filter to models that allow or don't allow derivatives to have a different license + /// + [AliasAs("allowDifferentLicenses")] + public bool? AllowDifferentLicenses { get; set; } + + /// + /// Filter to models based on their commercial permissions + /// + [AliasAs("allowCommercialUse")] + public CivitCommercialUse? AllowCommercialUse { get; set; } + + /// + /// If false, will return safer images and hide models that don't have safe images + /// + [AliasAs("nsfw")] + public string? Nsfw { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitModelsResponse.cs b/StabilityMatrix/Models/Api/CivitModelsResponse.cs new file mode 100644 index 00000000..62b5c8bc --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitModelsResponse.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitModelsResponse +{ + [JsonPropertyName("items")] + public CivitModel[]? Items { get; set; } + + [JsonPropertyName("metadata")] + public CivitMetadata? Metadata { get; set; } +} diff --git a/StabilityMatrix/Models/Api/CivitPeriod.cs b/StabilityMatrix/Models/Api/CivitPeriod.cs new file mode 100644 index 00000000..222629bb --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitPeriod.cs @@ -0,0 +1,13 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum CivitPeriod +{ + AllTime, + Year, + Month, + Week, + Day +} diff --git a/StabilityMatrix/Models/Api/CivitSortMode.cs b/StabilityMatrix/Models/Api/CivitSortMode.cs new file mode 100644 index 00000000..4ea22673 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitSortMode.cs @@ -0,0 +1,16 @@ +using System.Diagnostics.CodeAnalysis; +using System.Runtime.Serialization; +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum CivitSortMode +{ + [EnumMember(Value = "Highest Rated")] + HighestRated, + [EnumMember(Value = "Most Downloaded")] + MostDownloaded, + [EnumMember(Value = "Newest")] + Newest +} diff --git a/StabilityMatrix/Models/Api/CivitStats.cs b/StabilityMatrix/Models/Api/CivitStats.cs new file mode 100644 index 00000000..69c37420 --- /dev/null +++ b/StabilityMatrix/Models/Api/CivitStats.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Models.Api; + +public class CivitStats +{ + [JsonPropertyName("downloadCount")] + public int DownloadCount { get; set; } + + [JsonPropertyName("ratingCount")] + public int RatingCount { get; set; } + + [JsonPropertyName("rating")] + public double Rating { get; set; } +} diff --git a/StabilityMatrix/Models/CheckpointFile.cs b/StabilityMatrix/Models/CheckpointFile.cs index 7c436534..ad83f1ad 100644 --- a/StabilityMatrix/Models/CheckpointFile.cs +++ b/StabilityMatrix/Models/CheckpointFile.cs @@ -1,41 +1,93 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using System.IO; using System.Linq; +using System.Threading.Tasks; using System.Windows.Media.Imaging; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using NLog; namespace StabilityMatrix.Models; -public class CheckpointFile +public partial class CheckpointFile : ObservableObject { + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + // Event for when this file is deleted + public event EventHandler? Deleted; + /// /// Absolute path to the checkpoint file. /// public string FilePath { get; init; } = string.Empty; - + /// /// Custom title for UI. /// - public string Title { get; init; } = string.Empty; - - public string? PreviewImagePath { get; set; } + [ObservableProperty] private string title = string.Empty; + public string? PreviewImagePath { get; set; } public BitmapImage? PreviewImage { get; set; } - public bool IsPreviewImageLoaded => PreviewImage != null; + + [ObservableProperty] private ConnectedModelInfo? connectedModel; + public bool IsConnectedModel => ConnectedModel != null; + + [ObservableProperty] private bool isLoading; public string FileName => Path.GetFileName(FilePath); - private static readonly string[] SupportedCheckpointExtensions = { ".safetensors", ".pt" }; + private static readonly string[] SupportedCheckpointExtensions = { ".safetensors", ".pt", ".ckpt", ".pth" }; private static readonly string[] SupportedImageExtensions = { ".png", ".jpg", ".jpeg" }; + partial void OnConnectedModelChanged(ConnectedModelInfo? value) + { + if (value == null) return; + // Update title, first check user defined, then connected model name + Title = value.UserTitle ?? value.ModelName; + } + + [RelayCommand] + private async Task DeleteAsync() + { + if (File.Exists(FilePath)) + { + // Start progress ring + IsLoading = true; + var timer = Stopwatch.StartNew(); + try + { + await Task.Run(() => File.Delete(FilePath)); + if (PreviewImagePath != null && File.Exists(PreviewImagePath)) + { + await Task.Run(() => File.Delete(PreviewImagePath)); + } + // If it was too fast, wait a bit to show progress ring + var targetDelay = new Random().Next(200, 500); + var elapsed = timer.ElapsedMilliseconds; + if (elapsed < targetDelay) + { + await Task.Delay(targetDelay - (int) elapsed); + } + } + catch (IOException e) + { + Logger.Error(e, $"Failed to delete checkpoint file: {FilePath}"); + IsLoading = false; + return; // Don't delete from collection + } + } + Deleted?.Invoke(this, this); + } /// /// Indexes directory and yields all checkpoint files. /// First we match all files with supported extensions. /// If found, we also look for - /// - {filename}.preview.{image-extensions} + /// - {filename}.preview.{image-extensions} (preview image) + /// - {filename}.cm-info.json (connected model info) /// public static IEnumerable FromDirectoryIndex(string directory, SearchOption searchOption = SearchOption.TopDirectoryOnly) { @@ -52,6 +104,22 @@ public class CheckpointFile Title = Path.GetFileNameWithoutExtension(file), FilePath = Path.Combine(directory, file), }; + + // Check for connected model info + var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(file); + var cmInfoPath = $"{fileNameWithoutExtension}.cm-info.json"; + if (files.ContainsKey(cmInfoPath)) + { + try + { + var jsonData = File.ReadAllText(Path.Combine(directory, cmInfoPath)); + checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(jsonData); + } + catch (IOException e) + { + Debug.WriteLine($"Failed to parse {cmInfoPath}: {e}"); + } + } // Check for preview image var previewImage = SupportedImageExtensions.Select(ext => $"{checkpointFile.FileName}.preview.{ext}").FirstOrDefault(files.ContainsKey); diff --git a/StabilityMatrix/Models/CheckpointFolder.cs b/StabilityMatrix/Models/CheckpointFolder.cs index 9b885b83..705c808c 100644 --- a/StabilityMatrix/Models/CheckpointFolder.cs +++ b/StabilityMatrix/Models/CheckpointFolder.cs @@ -1,10 +1,19 @@ using System; +using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Collections.Specialized; +using System.IO; +using System.Linq; using System.Threading.Tasks; +using System.Windows; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using StabilityMatrix.Helper; +using StabilityMatrix.ViewModels; namespace StabilityMatrix.Models; -public class CheckpointFolder +public partial class CheckpointFolder : ObservableObject { /// /// Absolute path to the folder. @@ -16,7 +25,102 @@ public class CheckpointFolder /// public string Title { get; init; } = string.Empty; + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsDragBlurEnabled))] + private bool isCurrentDragTarget; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsDragBlurEnabled))] + private bool isImportInProgress; + + public bool IsDragBlurEnabled => IsCurrentDragTarget || IsImportInProgress; + + public ProgressViewModel Progress { get; } = new(); + public ObservableCollection CheckpointFiles { get; set; } = new(); + + public RelayCommand OnPreviewDragEnterCommand => new(() => IsCurrentDragTarget = true); + public RelayCommand OnPreviewDragLeaveCommand => new(() => IsCurrentDragTarget = false); + + public CheckpointFolder() + { + CheckpointFiles.CollectionChanged += OnCheckpointFilesChanged; + } + + // On collection changes + private void OnCheckpointFilesChanged(object? sender, NotifyCollectionChangedEventArgs e) + { + if (e.NewItems == null) return; + // On new added items, add event handler for deletion + foreach (CheckpointFile item in e.NewItems) + { + item.Deleted += OnCheckpointFileDelete; + } + } + + /// + /// Handler for CheckpointFile requesting to be deleted from the collection. + /// + /// + /// + private void OnCheckpointFileDelete(object? sender, CheckpointFile file) + { + Application.Current.Dispatcher.Invoke(() => CheckpointFiles.Remove(file)); + } + + [RelayCommand] + private async Task OnPreviewDropAsync(DragEventArgs e) + { + IsImportInProgress = true; + IsCurrentDragTarget = false; + + var files = e.Data.GetData(DataFormats.FileDrop) as string[]; + if (files == null || files.Length < 1) + { + IsImportInProgress = false; + return; + } + + await ImportFilesAsync(files); + } + + /// + /// Imports files to the folder. Reports progress to instance properties. + /// + public async Task ImportFilesAsync(IEnumerable files) + { + Progress.IsIndeterminate = true; + Progress.IsProgressVisible = true; + var copyPaths = files.ToDictionary(k => k, v => Path.Combine(DirectoryPath, Path.GetFileName(v))); + + var progress = new Progress(report => + { + Progress.IsIndeterminate = false; + Progress.Value = report.Percentage; + // For multiple files, add count + Progress.Text = copyPaths.Count > 1 ? $"Importing {report.Title} ({report.Message})" : $"Importing {report.Title}"; + }); + + await FileTransfers.CopyFiles(copyPaths, progress); + Progress.Value = 100; + Progress.Text = "Import complete"; + await IndexAsync(); + DelayedClearProgress(TimeSpan.FromSeconds(1)); + } + + /// + /// Clears progress after a delay. + /// + private void DelayedClearProgress(TimeSpan delay) + { + Task.Delay(delay).ContinueWith(_ => + { + IsImportInProgress = false; + Progress.IsProgressVisible = false; + Progress.Value = 0; + Progress.Text = string.Empty; + }); + } /// /// Indexes the folder for checkpoint files. diff --git a/StabilityMatrix/Models/ConnectedModelInfo.cs b/StabilityMatrix/Models/ConnectedModelInfo.cs new file mode 100644 index 00000000..245feae2 --- /dev/null +++ b/StabilityMatrix/Models/ConnectedModelInfo.cs @@ -0,0 +1,49 @@ +using System; +using System.Text.Json; +using StabilityMatrix.Extensions; +using StabilityMatrix.Models.Api; + +namespace StabilityMatrix.Models; + +public class ConnectedModelInfo +{ + public int ModelId { get; set; } + public string ModelName { get; set; } + public string ModelDescription { get; set; } + public bool Nsfw { get; set; } + public string[] Tags { get; set; } + public CivitModelType ModelType { get; set; } + public int VersionId { get; set; } + public string VersionName { get; set; } + public string VersionDescription { get; set; } + public string? BaseModel { get; set; } + public CivitFileMetadata FileMetadata { get; set; } + public DateTime ImportedAt { get; set; } + public CivitFileHashes Hashes { get; set; } + + // User settings + public string? UserTitle { get; set; } + public string? ThumbnailImageUrl { get; set; } + + public ConnectedModelInfo(CivitModel civitModel, CivitModelVersion civitModelVersion, CivitFile civitFile, DateTime importedAt) + { + ModelId = civitModel.Id; + ModelName = civitModel.Name; + ModelDescription = civitModel.Description; + Nsfw = civitModel.Nsfw; + Tags = civitModel.Tags; + ModelType = civitModel.Type; + VersionId = civitModelVersion.Id; + VersionName = civitModelVersion.Name; + VersionDescription = civitModelVersion.Description; + ImportedAt = importedAt; + BaseModel = civitModelVersion.BaseModel; + FileMetadata = civitFile.Metadata; + Hashes = civitFile.Hashes; + } + + public static ConnectedModelInfo? FromJson(string json) + { + return JsonSerializer.Deserialize(json); + } +} diff --git a/StabilityMatrix/Models/ISharedFolders.cs b/StabilityMatrix/Models/ISharedFolders.cs index 3f8cb2a7..531bbccc 100644 --- a/StabilityMatrix/Models/ISharedFolders.cs +++ b/StabilityMatrix/Models/ISharedFolders.cs @@ -4,7 +4,5 @@ namespace StabilityMatrix.Models; public interface ISharedFolders { - string SharedFoldersPath { get; } - string SharedFolderTypeToName(SharedFolderType folderType); void SetupLinksForPackage(BasePackage basePackage, string installPath); } diff --git a/StabilityMatrix/Models/Packages/A3WebUI.cs b/StabilityMatrix/Models/Packages/A3WebUI.cs index 5e39583d..a463b036 100644 --- a/StabilityMatrix/Models/Packages/A3WebUI.cs +++ b/StabilityMatrix/Models/Packages/A3WebUI.cs @@ -7,6 +7,7 @@ using System.Text.RegularExpressions; using System.Threading.Tasks; using StabilityMatrix.Helper; using StabilityMatrix.Helper.Cache; +using StabilityMatrix.Python; using StabilityMatrix.Services; namespace StabilityMatrix.Models.Packages; @@ -40,6 +41,9 @@ public class A3WebUI : BaseGitPackage [SharedFolderType.VAE] = "models/VAE", [SharedFolderType.DeepDanbooru] = "models/deepbooru", [SharedFolderType.Karlo] = "models/karlo", + [SharedFolderType.TextualInversion] = "embeddings", + [SharedFolderType.Hypernetwork] = "models/hypernetworks", + [SharedFolderType.ControlNet] = "models/ControlNet" }; public override List LaunchOptions => new() @@ -97,7 +101,7 @@ public class A3WebUI : BaseGitPackage var allReleases = await GetAllReleases(); return allReleases.Select(r => new PackageVersion {TagName = r.TagName!, ReleaseNotesMarkdown = r.Body}); } - else // branch mode1 + else // branch mode { var allBranches = await GetAllBranches(); return allBranches.Select(b => new PackageVersion @@ -108,6 +112,41 @@ public class A3WebUI : BaseGitPackage } } + public override async Task InstallPackage(bool isUpdate = false) + { + UnzipPackage(isUpdate); + OnInstallProgressChanged(-1); // Indeterminate progress bar + + Logger.Debug("Setting up venv"); + await SetupVenv(InstallLocation); + var venvRunner = new PyVenvRunner(Path.Combine(InstallLocation, "venv")); + + void HandleConsoleOutput(string? s) + { + Debug.WriteLine($"venv stdout: {s}"); + OnConsoleOutput(s); + } + + // install prereqs + await venvRunner.PipInstall(venvRunner.GetTorchInstallCommand(), InstallLocation, HandleConsoleOutput); + if (HardwareHelper.HasNvidiaGpu()) + { + await venvRunner.PipInstall("xformers", InstallLocation, HandleConsoleOutput); + } + + await venvRunner.PipInstall("-r requirements.txt", InstallLocation, HandleConsoleOutput); + + Logger.Debug("Finished installing requirements"); + if (isUpdate) + { + OnUpdateComplete("Update complete"); + } + else + { + OnInstallComplete("Install complete"); + } + } + public override async Task RunPackage(string installedPackagePath, string arguments) { await SetupVenv(installedPackagePath); diff --git a/StabilityMatrix/Models/Packages/BaseGitPackage.cs b/StabilityMatrix/Models/Packages/BaseGitPackage.cs index 601b7aac..fe3e21b1 100644 --- a/StabilityMatrix/Models/Packages/BaseGitPackage.cs +++ b/StabilityMatrix/Models/Packages/BaseGitPackage.cs @@ -108,19 +108,13 @@ public abstract class BaseGitPackage : BasePackage Directory.CreateDirectory(DownloadLocation.Replace($"{Name}.zip", "")); } - void DownloadProgressHandler(object? _, ProgressReport progress) => + var progress = new Progress(progress => + { DownloadServiceOnDownloadProgressChanged(progress, isUpdate); - - void DownloadFinishedHandler(object? _, ProgressReport downloadLocation) => - DownloadServiceOnDownloadFinished(downloadLocation, isUpdate); - - DownloadService.DownloadProgressChanged += DownloadProgressHandler; - DownloadService.DownloadComplete += DownloadFinishedHandler; - - await DownloadService.DownloadToFileAsync(downloadUrl, DownloadLocation); + }); - DownloadService.DownloadProgressChanged -= DownloadProgressHandler; - DownloadService.DownloadComplete -= DownloadFinishedHandler; + await DownloadService.DownloadToFileAsync(downloadUrl, DownloadLocation, progress: progress); + DownloadServiceOnDownloadFinished(new ProgressReport(100, "Download Complete"), isUpdate); return version; } diff --git a/StabilityMatrix/Models/SharedFolderType.cs b/StabilityMatrix/Models/SharedFolderType.cs index 0478b6ae..c9c3c51c 100644 --- a/StabilityMatrix/Models/SharedFolderType.cs +++ b/StabilityMatrix/Models/SharedFolderType.cs @@ -16,4 +16,7 @@ public enum SharedFolderType ApproxVAE, Karlo, DeepDanbooru, + TextualInversion, + Hypernetwork, + ControlNet } diff --git a/StabilityMatrix/Models/SharedFolders.cs b/StabilityMatrix/Models/SharedFolders.cs index 1879a1ce..54be07e0 100644 --- a/StabilityMatrix/Models/SharedFolders.cs +++ b/StabilityMatrix/Models/SharedFolders.cs @@ -2,6 +2,7 @@ using System.IO; using NCode.ReparsePoints; using NLog; +using StabilityMatrix.Extensions; using StabilityMatrix.Models.Packages; namespace StabilityMatrix.Models; @@ -10,11 +11,11 @@ public class SharedFolders : ISharedFolders { private const string SharedFoldersName = "Models"; private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - public string SharedFoldersPath { get; } = + public static string SharedFoldersPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "StabilityMatrix", SharedFoldersName); - public string SharedFolderTypeToName(SharedFolderType folderType) + public static string SharedFolderTypeToName(SharedFolderType folderType) { return Enum.GetName(typeof(SharedFolderType), folderType)!; } @@ -30,7 +31,7 @@ public class SharedFolders : ISharedFolders var provider = ReparsePointFactory.Provider; foreach (var (folderType, relativePath) in sharedFolders) { - var source = Path.GetFullPath(Path.Combine(SharedFoldersPath, SharedFolderTypeToName(folderType))); + var source = Path.GetFullPath(Path.Combine(SharedFoldersPath, folderType.GetStringValue())); var destination = Path.GetFullPath(Path.Combine(installPath, relativePath)); // Create source folder if it doesn't exist if (!Directory.Exists(source)) diff --git a/StabilityMatrix/Services/DownloadService.cs b/StabilityMatrix/Services/DownloadService.cs index 5dd998d3..395d8f53 100644 --- a/StabilityMatrix/Services/DownloadService.cs +++ b/StabilityMatrix/Services/DownloadService.cs @@ -5,6 +5,7 @@ using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using Polly.Contrib.WaitAndRetry; using StabilityMatrix.Models; namespace StabilityMatrix.Services; @@ -19,11 +20,9 @@ public class DownloadService : IDownloadService this.logger = logger; this.httpClientFactory = httpClientFactory; } - - public event EventHandler? DownloadProgressChanged; - public event EventHandler? DownloadComplete; - - public async Task DownloadToFileAsync(string downloadUrl, string downloadLocation, int bufferSize = ushort.MaxValue) + + public async Task DownloadToFileAsync(string downloadUrl, string downloadLocation, int bufferSize = ushort.MaxValue, + IProgress? progress = null) { using var client = httpClientFactory.CreateClient(); client.Timeout = TimeSpan.FromMinutes(5); @@ -31,26 +30,28 @@ public class DownloadService : IDownloadService await using var file = new FileStream(downloadLocation, FileMode.Create, FileAccess.Write, FileShare.None); long contentLength = 0; - var retryCount = 0; - + var response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); contentLength = response.Content.Headers.ContentLength ?? 0; - while (contentLength == 0 && retryCount++ < 5) + var delays = Backoff.DecorrelatedJitterBackoffV2( + TimeSpan.FromMilliseconds(50), retryCount: 3); + + foreach (var delay in delays) { + if (contentLength > 0) break; logger.LogDebug("Retrying get-headers for content-length"); - Thread.Sleep(50); + await Task.Delay(delay); response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); contentLength = response.Content.Headers.ContentLength ?? 0; } - var isIndeterminate = contentLength == 0; await using var stream = await response.Content.ReadAsStreamAsync(); - var totalBytesRead = 0; + var totalBytesRead = 0L; + var buffer = new byte[bufferSize]; while (true) { - var buffer = new byte[bufferSize]; var bytesRead = await stream.ReadAsync(buffer); if (bytesRead == 0) break; await file.WriteAsync(buffer.AsMemory(0, bytesRead)); @@ -59,22 +60,15 @@ public class DownloadService : IDownloadService if (isIndeterminate) { - OnDownloadProgressChanged(-1); + progress?.Report(new ProgressReport(-1, isIndeterminate: true)); } else { - var progress = totalBytesRead / (double) contentLength; - OnDownloadProgressChanged(progress); + progress?.Report(new ProgressReport(current: Convert.ToUInt64(totalBytesRead), + total: Convert.ToUInt64(contentLength))); } } await file.FlushAsync(); - OnDownloadComplete(downloadLocation); } - - private void OnDownloadProgressChanged(double progress) => - DownloadProgressChanged?.Invoke(this, new ProgressReport(progress)); - - private void OnDownloadComplete(string path) => - DownloadComplete?.Invoke(this, new ProgressReport(progress: 100f, message: path)); } diff --git a/StabilityMatrix/Services/IDownloadService.cs b/StabilityMatrix/Services/IDownloadService.cs index 889dc09c..a1e529bf 100644 --- a/StabilityMatrix/Services/IDownloadService.cs +++ b/StabilityMatrix/Services/IDownloadService.cs @@ -6,7 +6,6 @@ namespace StabilityMatrix.Services; public interface IDownloadService { - event EventHandler? DownloadProgressChanged; - event EventHandler? DownloadComplete; - Task DownloadToFileAsync(string downloadUrl, string downloadLocation, int bufferSize = ushort.MaxValue); + Task DownloadToFileAsync(string downloadUrl, string downloadLocation, int bufferSize = ushort.MaxValue, + IProgress? progress = null); } diff --git a/StabilityMatrix/SettingsPage.xaml b/StabilityMatrix/SettingsPage.xaml index 0a03fae4..3e5523c5 100644 --- a/StabilityMatrix/SettingsPage.xaml +++ b/StabilityMatrix/SettingsPage.xaml @@ -103,6 +103,10 @@ Command="{Binding PingWebApiCommand}" Content="Ping Web API" Margin="8" /> +