diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index 8b4678a3..f73fa6f5 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -1,11 +1,11 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.ComponentModel.DataAnnotations; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; -using System.Management; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; @@ -15,7 +15,6 @@ using Avalonia.Controls.Notifications; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using ExifLibrary; -using MetadataExtractor.Formats.Exif; using NLog; using Refit; using SkiaSharp; @@ -27,7 +26,6 @@ using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; -using StabilityMatrix.Core.Animation; using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; @@ -320,14 +318,18 @@ public abstract partial class InferenceGenerationViewModelBase Task.Run( async () => { - var delayTime = 250 - (int)timer.ElapsedMilliseconds; - if (delayTime > 0) + try { - await Task.Delay(delayTime, cancellationToken); + var delayTime = 250 - (int)timer.ElapsedMilliseconds; + if (delayTime > 0) + { + await Task.Delay(delayTime, cancellationToken); + } + + // ReSharper disable once AccessToDisposedClosure + AttachRunningNodeChangedHandler(promptTask); } - - // ReSharper disable once AccessToDisposedClosure - AttachRunningNodeChangedHandler(promptTask); + catch (TaskCanceledException) { } }, cancellationToken ) @@ -351,10 +353,7 @@ public abstract partial class InferenceGenerationViewModelBase // Get output images var imageOutputs = await client.GetImagesForExecutedPromptAsync(promptTask.Id, cancellationToken); - if ( - !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) - || images is not { Count: > 0 } - ) + if (imageOutputs.Values.All(images => images is null or { Count: 0 })) { // No images match notificationService.Show( @@ -373,7 +372,7 @@ public abstract partial class InferenceGenerationViewModelBase ImageGalleryCardViewModel.ImageSources.Clear(); } - var outputImages = await ProcessOutputImages(images, args); + var outputImages = await ProcessAllOutputImages(imageOutputs, args); var notificationImage = outputImages.FirstOrDefault()?.LocalFile; @@ -403,12 +402,34 @@ public abstract partial class InferenceGenerationViewModelBase } } + private async Task> ProcessAllOutputImages( + IReadOnlyDictionary?> images, + ImageGenerationEventArgs args + ) + { + var results = new List(); + + foreach (var (nodeName, imageList) in images) + { + if (imageList is null) + { + Logger.Warn("No images for node {NodeName}", nodeName); + continue; + } + + results.AddRange(await ProcessOutputImages(imageList, args, nodeName.Replace('_', ' '))); + } + + return results; + } + /// /// Handles image output metadata for generation runs /// private async Task> ProcessOutputImages( IReadOnlyCollection images, - ImageGenerationEventArgs args + ImageGenerationEventArgs args, + string? imageLabel = null ) { var client = args.Client; @@ -464,7 +485,7 @@ public abstract partial class InferenceGenerationViewModelBase images.Count ); - outputImages.Add(new ImageSource(filePath)); + outputImages.Add(new ImageSource(filePath) { Label = imageLabel }); EventManager.Instance.OnImageFileAdded(filePath); } else if (comfyImage.FileName.EndsWith(".webp")) @@ -493,7 +514,7 @@ public abstract partial class InferenceGenerationViewModelBase fileExtension: Path.GetExtension(comfyImage.FileName).Replace(".", "") ); - outputImages.Add(new ImageSource(filePath)); + outputImages.Add(new ImageSource(filePath) { Label = imageLabel }); EventManager.Instance.OnImageFileAdded(filePath); } else @@ -507,7 +528,7 @@ public abstract partial class InferenceGenerationViewModelBase fileExtension: Path.GetExtension(comfyImage.FileName).Replace(".", "") ); - outputImages.Add(new ImageSource(filePath)); + outputImages.Add(new ImageSource(filePath) { Label = imageLabel }); EventManager.Instance.OnImageFileAdded(filePath); } } @@ -577,7 +598,12 @@ public abstract partial class InferenceGenerationViewModelBase } catch (OperationCanceledException) { - Logger.Debug($"Image Generation Canceled"); + Logger.Debug("Image Generation Canceled"); + } + catch (ValidationException e) + { + Logger.Debug("Image Generation Validation Error: {Message}", e.Message); + notificationService.Show("Validation Error", e.Message, NotificationType.Error); } }