Browse Source

Add back some stuff removed in 81652930

(cherry picked from commit 00e0bfc3c3)
pull/463/head
Ionite 9 months ago
parent
commit
a3d95abf2d
  1. 64
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

64
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -1,11 +1,11 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Management;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
@ -15,7 +15,6 @@ using Avalonia.Controls.Notifications;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using ExifLibrary;
using MetadataExtractor.Formats.Exif;
using NLog;
using Refit;
using SkiaSharp;
@ -27,7 +26,6 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Animation;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
@ -297,14 +295,18 @@ public abstract partial class InferenceGenerationViewModelBase
Task.Run(
async () =>
{
var delayTime = 250 - (int)timer.ElapsedMilliseconds;
if (delayTime > 0)
try
{
await Task.Delay(delayTime, cancellationToken);
var delayTime = 250 - (int)timer.ElapsedMilliseconds;
if (delayTime > 0)
{
await Task.Delay(delayTime, cancellationToken);
}
// ReSharper disable once AccessToDisposedClosure
AttachRunningNodeChangedHandler(promptTask);
}
// ReSharper disable once AccessToDisposedClosure
AttachRunningNodeChangedHandler(promptTask);
catch (TaskCanceledException) { }
},
cancellationToken
)
@ -328,10 +330,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(promptTask.Id, cancellationToken);
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
|| images is not { Count: > 0 }
)
if (imageOutputs.Values.All(images => images is null or { Count: 0 }))
{
// No images match
notificationService.Show(
@ -350,7 +349,7 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGalleryCardViewModel.ImageSources.Clear();
}
var outputImages = await ProcessOutputImages(images, args);
var outputImages = await ProcessAllOutputImages(imageOutputs, args);
var notificationImage = outputImages.FirstOrDefault()?.LocalFile;
@ -380,12 +379,34 @@ public abstract partial class InferenceGenerationViewModelBase
}
}
private async Task<IEnumerable<ImageSource>> ProcessAllOutputImages(
IReadOnlyDictionary<string, List<ComfyImage>?> images,
ImageGenerationEventArgs args
)
{
var results = new List<ImageSource>();
foreach (var (nodeName, imageList) in images)
{
if (imageList is null)
{
Logger.Warn("No images for node {NodeName}", nodeName);
continue;
}
results.AddRange(await ProcessOutputImages(imageList, args, nodeName.Replace('_', ' ')));
}
return results;
}
/// <summary>
/// Handles image output metadata for generation runs
/// </summary>
private async Task<List<ImageSource>> ProcessOutputImages(
IReadOnlyCollection<ComfyImage> images,
ImageGenerationEventArgs args
ImageGenerationEventArgs args,
string? imageLabel = null
)
{
var client = args.Client;
@ -441,7 +462,7 @@ public abstract partial class InferenceGenerationViewModelBase
images.Count
);
outputImages.Add(new ImageSource(filePath));
outputImages.Add(new ImageSource(filePath) { Label = imageLabel });
EventManager.Instance.OnImageFileAdded(filePath);
}
else if (comfyImage.FileName.EndsWith(".webp"))
@ -470,7 +491,7 @@ public abstract partial class InferenceGenerationViewModelBase
fileExtension: Path.GetExtension(comfyImage.FileName).Replace(".", "")
);
outputImages.Add(new ImageSource(filePath));
outputImages.Add(new ImageSource(filePath) { Label = imageLabel });
EventManager.Instance.OnImageFileAdded(filePath);
}
else
@ -484,7 +505,7 @@ public abstract partial class InferenceGenerationViewModelBase
fileExtension: Path.GetExtension(comfyImage.FileName).Replace(".", "")
);
outputImages.Add(new ImageSource(filePath));
outputImages.Add(new ImageSource(filePath) { Label = imageLabel });
EventManager.Instance.OnImageFileAdded(filePath);
}
}
@ -554,7 +575,12 @@ public abstract partial class InferenceGenerationViewModelBase
}
catch (OperationCanceledException)
{
Logger.Debug($"Image Generation Canceled");
Logger.Debug("Image Generation Canceled");
}
catch (ValidationException e)
{
Logger.Debug("Image Generation Validation Error: {Message}", e.Message);
notificationService.Show("Validation Error", e.Message, NotificationType.Error);
}
}

Loading…
Cancel
Save