From 35644f7d1df25ca946740e33401c35c3f156ea52 Mon Sep 17 00:00:00 2001 From: Ionite Date: Sat, 2 Dec 2023 19:56:49 -0500 Subject: [PATCH] Add multiple image outputs and labels --- .../Models/ImageSource.cs | 5 +++ .../Base/InferenceGenerationViewModelBase.cs | 32 ++++++++++++++----- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/StabilityMatrix.Avalonia/Models/ImageSource.cs b/StabilityMatrix.Avalonia/Models/ImageSource.cs index d5a33b87..ab1c1023 100644 --- a/StabilityMatrix.Avalonia/Models/ImageSource.cs +++ b/StabilityMatrix.Avalonia/Models/ImageSource.cs @@ -29,6 +29,11 @@ public record ImageSource : IDisposable /// public Bitmap? Bitmap { get; set; } + /// + /// Optional label for the image + /// + public string? Label { get; set; } + public ImageSource(FilePath localFile) { LocalFile = localFile; diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index d27b1500..f493cb7a 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -326,10 +326,7 @@ public abstract partial class InferenceGenerationViewModelBase cancellationToken ); - if ( - !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) - || images is not { Count: > 0 } - ) + if (imageOutputs.Values.All(images => images is null or { Count: 0 })) { // No images match notificationService.Show( @@ -348,7 +345,7 @@ public abstract partial class InferenceGenerationViewModelBase ImageGalleryCardViewModel.ImageSources.Clear(); } - await ProcessOutputImages(images, args); + await ProcessAllOutputImages(imageOutputs, args); } finally { @@ -366,12 +363,30 @@ public abstract partial class InferenceGenerationViewModelBase } } + private async Task ProcessAllOutputImages( + IReadOnlyDictionary?> images, + ImageGenerationEventArgs args + ) + { + foreach (var (nodeName, imageList) in images) + { + if (imageList is null) + { + Logger.Warn("No images for node {NodeName}", nodeName); + continue; + } + + await ProcessOutputImages(imageList, args, nodeName.Replace('_', ' ')); + } + } + /// /// Handles image output metadata for generation runs /// private async Task ProcessOutputImages( IReadOnlyCollection images, - ImageGenerationEventArgs args + ImageGenerationEventArgs args, + string? imageLabel = null ) { var client = args.Client; @@ -428,7 +443,7 @@ public abstract partial class InferenceGenerationViewModelBase images.Count ); - outputImages.Add(new ImageSource(filePath)); + outputImages.Add(new ImageSource(filePath) { Label = imageLabel }); EventManager.Instance.OnImageFileAdded(filePath); } @@ -462,7 +477,8 @@ public abstract partial class InferenceGenerationViewModelBase ); // Insert to start of images - var gridImage = new ImageSource(gridPath); + var gridImage = new ImageSource(gridPath) { Label = imageLabel }; + // Preload await gridImage.GetBitmapAsync(); ImageGalleryCardViewModel.ImageSources.Add(gridImage);