From a0f1c0a8f34bc42df75e7acba66bb6070ed6ad93 Mon Sep 17 00:00:00 2001 From: JT Date: Sun, 24 Dec 2023 23:55:16 -0800 Subject: [PATCH 1/2] load pngs fastly & fix metadata scan crash --- CHANGELOG.md | 2 + .../FallbackRamCachedWebImageLoader.cs | 7 +- .../ViewModels/OutputsPageViewModel.cs | 51 ++---- StabilityMatrix.Core/Helper/ImageMetadata.cs | 48 ++++++ .../Services/MetadataImportService.cs | 159 +++++++++++------- 5 files changed, 168 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57aed4b9..820add42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2 ## v2.7.5 ### Fixed - Fixed Python Packages manager crash when pip list returns warnings in json +- Fixed slowdown when loading PNGs with large amounts of metadata +- Fixed crash when scanning directories for missing metadata ## v2.7.4 ### Changed diff --git a/StabilityMatrix.Avalonia/FallbackRamCachedWebImageLoader.cs b/StabilityMatrix.Avalonia/FallbackRamCachedWebImageLoader.cs index f8ae8235..9c001a92 100644 --- a/StabilityMatrix.Avalonia/FallbackRamCachedWebImageLoader.cs +++ b/StabilityMatrix.Avalonia/FallbackRamCachedWebImageLoader.cs @@ -7,6 +7,7 @@ using AsyncAwaitBestPractices; using AsyncImageLoader.Loaders; using Avalonia.Media.Imaging; using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; namespace StabilityMatrix.Avalonia; @@ -42,7 +43,11 @@ public class FallbackRamCachedWebImageLoader : RamCachedWebImageLoader { try { - return new Bitmap(url); + if (!url.EndsWith("png", StringComparison.OrdinalIgnoreCase)) + return new Bitmap(url); + + using var stream = ImageMetadata.BuildImageWithoutMetadata(url); + return stream == null ? new Bitmap(url) : new Bitmap(stream); } catch (Exception e) { diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs index c1c66a72..03ca531c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs @@ -54,11 +54,9 @@ public partial class OutputsPageViewModel : PageViewModelBase private readonly ILogger logger; public override string Title => Resources.Label_OutputsPageTitle; - public override IconSource IconSource => - new SymbolIconSource { Symbol = Symbol.Grid, IsFilled = true }; + public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.Grid, IsFilled = true }; - public SourceCache OutputsCache { get; } = - new(file => file.AbsolutePath); + public SourceCache OutputsCache { get; } = new(file => file.AbsolutePath); public IObservableCollection Outputs { get; set; } = new ObservableCollectionExtended(); @@ -88,8 +86,7 @@ public partial class OutputsPageViewModel : PageViewModelBase [ObservableProperty] private bool isConsolidating; - public bool CanShowOutputTypes => - SelectedCategory?.Name?.Equals("Shared Output Folder") ?? false; + public bool CanShowOutputTypes => SelectedCategory?.Name?.Equals("Shared Output Folder") ?? false; public string NumImagesSelected => NumItemsSelected == 1 @@ -163,10 +160,7 @@ public partial class OutputsPageViewModel : PageViewModelBase GetOutputs(path); } - partial void OnSelectedCategoryChanged( - PackageOutputCategory? oldValue, - PackageOutputCategory? newValue - ) + partial void OnSelectedCategoryChanged(PackageOutputCategory? oldValue, PackageOutputCategory? newValue) { if (oldValue == newValue || newValue == null) return; @@ -223,8 +217,8 @@ public partial class OutputsPageViewModel : PageViewModelBase ) .Subscribe(ctx => { - Dispatcher.UIThread - .InvokeAsync(async () => + Dispatcher + .UIThread.InvokeAsync(async () => { var sender = (ImageViewerViewModel)ctx.Sender!; var newIndex = currentIndex + (ctx.EventArgs.IsNext ? 1 : -1); @@ -430,9 +424,7 @@ public partial class OutputsPageViewModel : PageViewModelBase Directory.CreateDirectory(settingsManager.ConsolidatedImagesDirectory); - foreach ( - var category in stackPanel.Children.OfType().Where(c => c.IsChecked == true) - ) + foreach (var category in stackPanel.Children.OfType().Where(c => c.IsChecked == true)) { if ( string.IsNullOrWhiteSpace(category.Tag?.ToString()) @@ -442,13 +434,7 @@ public partial class OutputsPageViewModel : PageViewModelBase var directory = category.Tag.ToString(); - foreach ( - var path in Directory.EnumerateFiles( - directory, - "*.png", - SearchOption.AllDirectories - ) - ) + foreach (var path in Directory.EnumerateFiles(directory, "*.png", SearchOption.AllDirectories)) { try { @@ -534,15 +520,11 @@ public partial class OutputsPageViewModel : PageViewModelBase var previouslySelectedCategory = SelectedCategory; - var packageCategories = settingsManager.Settings.InstalledPackages - .Where(x => !x.UseSharedOutputFolder) + var packageCategories = settingsManager + .Settings.InstalledPackages.Where(x => !x.UseSharedOutputFolder) .Select(packageFactory.GetPackagePair) .WhereNotNull() - .Where( - p => - p.BasePackage.SharedOutputFolders != null - && p.BasePackage.SharedOutputFolders.Any() - ) + .Where(p => p.BasePackage.SharedOutputFolders != null && p.BasePackage.SharedOutputFolders.Any()) .Select( pair => new PackageOutputCategory @@ -567,16 +549,11 @@ public partial class OutputsPageViewModel : PageViewModelBase packageCategories.Insert( 1, - new PackageOutputCategory - { - Path = settingsManager.ImagesInferenceDirectory, - Name = "Inference" - } + new PackageOutputCategory { Path = settingsManager.ImagesInferenceDirectory, Name = "Inference" } ); Categories = new ObservableCollection(packageCategories); - SelectedCategory = - Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name) - ?? Categories.First(); + selectedCategory = + Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name) ?? Categories.First(); } } diff --git a/StabilityMatrix.Core/Helper/ImageMetadata.cs b/StabilityMatrix.Core/Helper/ImageMetadata.cs index a10dae53..439242bf 100644 --- a/StabilityMatrix.Core/Helper/ImageMetadata.cs +++ b/StabilityMatrix.Core/Helper/ImageMetadata.cs @@ -179,4 +179,52 @@ public class ImageMetadata return string.Empty; } + + public static MemoryStream? BuildImageWithoutMetadata(FilePath imagePath) + { + using var byteStream = new BinaryReader(File.OpenRead(imagePath)); + byteStream.BaseStream.Position = 0; + + if (!byteStream.ReadBytes(8).SequenceEqual(PngHeader)) + { + return null; + } + + var memoryStream = new MemoryStream(); + memoryStream.Write(PngHeader); + + // add the IHDR chunk + var ihdrStuff = byteStream.ReadBytes(25); + memoryStream.Write(ihdrStuff); + + // find IDATs + while (byteStream.BaseStream.Position < byteStream.BaseStream.Length - 4) + { + var chunkSizeBytes = byteStream.ReadBytes(4); + var chunkSize = BitConverter.ToInt32(chunkSizeBytes.Reverse().ToArray()); + var chunkTypeBytes = byteStream.ReadBytes(4); + var chunkType = Encoding.UTF8.GetString(chunkTypeBytes); + + if (chunkType != Encoding.UTF8.GetString(Idat)) + { + // skip chunk data + byteStream.BaseStream.Position += chunkSize; + // skip crc + byteStream.BaseStream.Position += 4; + continue; + } + + memoryStream.Write(chunkSizeBytes); + memoryStream.Write(chunkTypeBytes); + var idatBytes = byteStream.ReadBytes(chunkSize); + memoryStream.Write(idatBytes); + var crcBytes = byteStream.ReadBytes(4); + memoryStream.Write(crcBytes); + } + + // Add IEND chunk + memoryStream.Write([0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82]); + memoryStream.Position = 0; + return memoryStream; + } } diff --git a/StabilityMatrix.Core/Services/MetadataImportService.cs b/StabilityMatrix.Core/Services/MetadataImportService.cs index 8a333dd0..5a88fb83 100644 --- a/StabilityMatrix.Core/Services/MetadataImportService.cs +++ b/StabilityMatrix.Core/Services/MetadataImportService.cs @@ -1,5 +1,4 @@ -using System.Diagnostics; -using System.Text.Json; +using System.Text.Json; using Microsoft.Extensions.Logging; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; @@ -18,7 +17,10 @@ public class MetadataImportService( ModelFinder modelFinder ) : IMetadataImportService { - public async Task ScanDirectoryForMissingInfo(DirectoryPath directory, IProgress? progress = null) + public async Task ScanDirectoryForMissingInfo( + DirectoryPath directory, + IProgress? progress = null + ) { progress?.Report(new ProgressReport(-1f, "Scanning directory...", isIndeterminate: true)); @@ -54,7 +56,9 @@ public class MetadataImportService( } var fileNameWithoutExtension = checkpointFilePath.NameWithoutExtension; - var cmInfoPath = checkpointFilePath.Directory?.JoinFile($"{fileNameWithoutExtension}.cm-info.json"); + var cmInfoPath = checkpointFilePath.Directory?.JoinFile( + $"{fileNameWithoutExtension}.cm-info.json" + ); var cmInfoExists = File.Exists(cmInfoPath); if (cmInfoExists) continue; @@ -70,43 +74,57 @@ public class MetadataImportService( ); }); - var blake3 = await GetBlake3Hash(cmInfoPath, checkpointFilePath, hashProgress).ConfigureAwait(false); - if (string.IsNullOrWhiteSpace(blake3)) + try { - logger.LogWarning($"Blake3 hash was null for {checkpointFilePath}"); - scanned++; - continue; - } - - var modelInfo = await modelFinder.RemoteFindModel(blake3).ConfigureAwait(false); - if (modelInfo == null) - { - logger.LogWarning($"Could not find model for {blake3}"); - scanned++; - continue; - } + var blake3 = await GetBlake3Hash(cmInfoPath, checkpointFilePath, hashProgress) + .ConfigureAwait(false); + if (string.IsNullOrWhiteSpace(blake3)) + { + logger.LogWarning($"Blake3 hash was null for {checkpointFilePath}"); + scanned++; + continue; + } + + var modelInfo = await modelFinder.RemoteFindModel(blake3).ConfigureAwait(false); + if (modelInfo == null) + { + logger.LogWarning($"Could not find model for {blake3}"); + scanned++; + continue; + } + + var (model, modelVersion, modelFile) = modelInfo.Value; + + var updatedCmInfo = new ConnectedModelInfo( + model, + modelVersion, + modelFile, + DateTimeOffset.UtcNow + ); + await updatedCmInfo + .SaveJsonToDirectory(checkpointFilePath.Directory, fileNameWithoutExtension) + .ConfigureAwait(false); - var (model, modelVersion, modelFile) = modelInfo.Value; + var image = modelVersion.Images?.FirstOrDefault( + img => LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) + ); + if (image == null) + { + scanned++; + success++; + continue; + } - var updatedCmInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTimeOffset.UtcNow); - await updatedCmInfo - .SaveJsonToDirectory(checkpointFilePath.Directory, fileNameWithoutExtension) - .ConfigureAwait(false); + await DownloadImage(image, checkpointFilePath, progress).ConfigureAwait(false); - var image = modelVersion - .Images - ?.FirstOrDefault(img => LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url))); - if (image == null) - { scanned++; success++; - continue; } - - await DownloadImage(image, checkpointFilePath, progress).ConfigureAwait(false); - - scanned++; - success++; + catch (Exception e) + { + logger.LogError(e, "Error while scanning {checkpointFilePath}", checkpointFilePath); + scanned++; + } } progress?.Report( @@ -124,7 +142,10 @@ public class MetadataImportService( && !File.Exists(file.Directory?.JoinFile($"{file.NameWithoutExtension}.cm-info.json")); } - public async Task UpdateExistingMetadata(DirectoryPath directory, IProgress? progress = null) + public async Task UpdateExistingMetadata( + DirectoryPath directory, + IProgress? progress = null + ) { progress?.Report(new ProgressReport(-1f, "Scanning directory...", isIndeterminate: true)); @@ -151,33 +172,47 @@ public class MetadataImportService( ) ); - var hash = cmInfoValue.Hashes.BLAKE3; - if (string.IsNullOrWhiteSpace(hash)) - continue; - - var modelInfo = await modelFinder.RemoteFindModel(hash).ConfigureAwait(false); - if (modelInfo == null) + try { - logger.LogWarning($"Could not find model for {hash}"); - continue; - } - - var (model, modelVersion, modelFile) = modelInfo.Value; - - var updatedCmInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTimeOffset.UtcNow); + var hash = cmInfoValue.Hashes.BLAKE3; + if (string.IsNullOrWhiteSpace(hash)) + continue; + + var modelInfo = await modelFinder.RemoteFindModel(hash).ConfigureAwait(false); + if (modelInfo == null) + { + logger.LogWarning($"Could not find model for {hash}"); + continue; + } + + var (model, modelVersion, modelFile) = modelInfo.Value; + + var updatedCmInfo = new ConnectedModelInfo( + model, + modelVersion, + modelFile, + DateTimeOffset.UtcNow + ); - var nameWithoutCmInfo = filePath.NameWithoutExtension.Replace(".cm-info", string.Empty); - await updatedCmInfo.SaveJsonToDirectory(filePath.Directory, nameWithoutCmInfo).ConfigureAwait(false); + var nameWithoutCmInfo = filePath.NameWithoutExtension.Replace(".cm-info", string.Empty); + await updatedCmInfo + .SaveJsonToDirectory(filePath.Directory, nameWithoutCmInfo) + .ConfigureAwait(false); - var image = modelVersion - .Images - ?.FirstOrDefault(img => LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url))); - if (image == null) - continue; + var image = modelVersion.Images?.FirstOrDefault( + img => LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) + ); + if (image == null) + continue; - await DownloadImage(image, filePath, progress).ConfigureAwait(false); + await DownloadImage(image, filePath, progress).ConfigureAwait(false); - success++; + success++; + } + catch (Exception e) + { + logger.LogError(e, "Error while updating {filePath}", filePath); + } } } @@ -223,11 +258,13 @@ public class MetadataImportService( var (model, modelVersion, modelFile) = modelInfo.Value; var updatedCmInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTimeOffset.UtcNow); - await updatedCmInfo.SaveJsonToDirectory(filePath.Directory, fileNameWithoutExtension).ConfigureAwait(false); + await updatedCmInfo + .SaveJsonToDirectory(filePath.Directory, fileNameWithoutExtension) + .ConfigureAwait(false); - var image = modelVersion - .Images - ?.FirstOrDefault(img => LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url))); + var image = modelVersion.Images?.FirstOrDefault( + img => LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) + ); if (image == null) return updatedCmInfo; From 90932c4826a614be1c69e2f499bc2740889181e3 Mon Sep 17 00:00:00 2001 From: JT Date: Mon, 25 Dec 2023 00:03:02 -0800 Subject: [PATCH 2/2] don't set the prop on first load --- .../ViewModels/OutputsPageViewModel.cs | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs index 03ca531c..b49ff6d1 100644 --- a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs @@ -135,7 +135,7 @@ public partial class OutputsPageViewModel : PageViewModelBase delay: TimeSpan.FromMilliseconds(250) ); - RefreshCategories(); + RefreshCategories(false); } public override void OnLoaded() @@ -255,7 +255,7 @@ public partial class OutputsPageViewModel : PageViewModelBase public void Refresh() { - Dispatcher.UIThread.Post(RefreshCategories); + Dispatcher.UIThread.Post(() => RefreshCategories()); Dispatcher.UIThread.Post(OnLoaded); } @@ -510,7 +510,7 @@ public partial class OutputsPageViewModel : PageViewModelBase } } - private void RefreshCategories() + private void RefreshCategories(bool updateProperty = true) { if (Design.IsDesignMode) return; @@ -553,7 +553,18 @@ public partial class OutputsPageViewModel : PageViewModelBase ); Categories = new ObservableCollection(packageCategories); - selectedCategory = - Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name) ?? Categories.First(); + + if (updateProperty) + { + SelectedCategory = + Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name) + ?? Categories.First(); + } + else + { + selectedCategory = + Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name) + ?? Categories.First(); + } } }