diff --git a/StabilityMatrix.Avalonia/Models/ImageSource.cs b/StabilityMatrix.Avalonia/Models/ImageSource.cs
index d5a33b87..ab1c1023 100644
--- a/StabilityMatrix.Avalonia/Models/ImageSource.cs
+++ b/StabilityMatrix.Avalonia/Models/ImageSource.cs
@@ -29,6 +29,11 @@ public record ImageSource : IDisposable
///
public Bitmap? Bitmap { get; set; }
+ ///
+ /// Optional label for the image
+ ///
+ public string? Label { get; set; }
+
public ImageSource(FilePath localFile)
{
LocalFile = localFile;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
index d27b1500..f493cb7a 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
@@ -326,10 +326,7 @@ public abstract partial class InferenceGenerationViewModelBase
cancellationToken
);
- if (
- !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
- || images is not { Count: > 0 }
- )
+ if (imageOutputs.Values.All(images => images is null or { Count: 0 }))
{
// No images match
notificationService.Show(
@@ -348,7 +345,7 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGalleryCardViewModel.ImageSources.Clear();
}
- await ProcessOutputImages(images, args);
+ await ProcessAllOutputImages(imageOutputs, args);
}
finally
{
@@ -366,12 +363,30 @@ public abstract partial class InferenceGenerationViewModelBase
}
}
+ private async Task ProcessAllOutputImages(
+ IReadOnlyDictionary?> images,
+ ImageGenerationEventArgs args
+ )
+ {
+ foreach (var (nodeName, imageList) in images)
+ {
+ if (imageList is null)
+ {
+ Logger.Warn("No images for node {NodeName}", nodeName);
+ continue;
+ }
+
+ await ProcessOutputImages(imageList, args, nodeName.Replace('_', ' '));
+ }
+ }
+
///
/// Handles image output metadata for generation runs
///
private async Task ProcessOutputImages(
IReadOnlyCollection images,
- ImageGenerationEventArgs args
+ ImageGenerationEventArgs args,
+ string? imageLabel = null
)
{
var client = args.Client;
@@ -428,7 +443,7 @@ public abstract partial class InferenceGenerationViewModelBase
images.Count
);
- outputImages.Add(new ImageSource(filePath));
+ outputImages.Add(new ImageSource(filePath) { Label = imageLabel });
EventManager.Instance.OnImageFileAdded(filePath);
}
@@ -462,7 +477,8 @@ public abstract partial class InferenceGenerationViewModelBase
);
// Insert to start of images
- var gridImage = new ImageSource(gridPath);
+ var gridImage = new ImageSource(gridPath) { Label = imageLabel };
+
// Preload
await gridImage.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(gridImage);