using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
/// <summary>
/// Abstract base class for tab view models that generate images using ClientManager.
/// This includes a progress reporter, image output view model, and generation virtual methods.
/// </summary>
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract partial class InferenceGenerationViewModelBase
: InferenceTabViewModelBase,
IImageGalleryComponent
{
// Class-wide NLog logger
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
// Used by RunGeneration to report that no output images were received
private readonly INotificationService notificationService;
// Gallery that shows live previews and the finished images; serialized under the JSON name "ImageGallery"
[JsonPropertyName("ImageGallery")]
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
// Resolved from the view model factory in the constructor; excluded from JSON serialization
[JsonIgnore]
public ImageFolderCardViewModel ImageFolderCardViewModel { get; }
// Progress state (value, maximum, text) updated while a prompt runs; excluded from JSON serialization
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
// Inference client manager supplied through the constructor; excluded from JSON serialization
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
/// <summary>
/// Stores the injected services, resolves the gallery and folder card view models
/// from <paramref name="vmFactory"/>, and attaches the notification error handler
/// to the generated GenerateImageCommand.
/// </summary>
/// <param name="vmFactory">Factory used to resolve the child view models</param>
/// <param name="inferenceClientManager">Exposed through <see cref="ClientManager"/></param>
/// <param name="notificationService">Service used to show notifications to the user</param>
protected InferenceGenerationViewModelBase(
    ServiceManager vmFactory,
    IInferenceClientManager inferenceClientManager,
    INotificationService notificationService
)
{
    // Injected services
    ClientManager = inferenceClientManager;
    this.notificationService = notificationService;

    // Child view models (resolved in the same order as before)
    ImageGalleryCardViewModel = vmFactory.Get();
    ImageFolderCardViewModel = vmFactory.Get();

    // Hook the command up to the notification service for error reporting
    GenerateImageCommand.WithConditionalNotificationErrorHandler(this.notificationService);
}
/// <summary>
/// Hook for building the image generation prompt.
/// The base implementation does nothing.
/// </summary>
/// <param name="args">Event args carrying the node builder and the generation overrides</param>
protected virtual void BuildPrompt(BuildPromptEventArgs args)
{
    // Intentionally empty in the base class.
}
/// <summary>
/// Runs a generation task: queues the prompt on the client, waits for it to finish,
/// then loads the images of the first output node into the image gallery.
/// </summary>
/// <param name="args">Client, node graph, output node names and image metadata for this run</param>
/// <param name="cancellationToken">Token that interrupts the running prompt when cancelled</param>
/// <exception cref="InvalidOperationException">
/// Thrown if args.Parameters or args.Project are null, if args.OutputNodeNames is empty,
/// or if the client's OutputImagesDir is null
/// </exception>
protected async Task RunGeneration(
    ImageGenerationEventArgs args,
    CancellationToken cancellationToken
)
{
    var client = args.Client;
    var nodes = args.Nodes;

    // Checks
    if (args.Parameters is null)
        throw new InvalidOperationException("Parameters is null");
    if (args.Project is null)
        throw new InvalidOperationException("Project is null");
    if (args.OutputNodeNames.Count == 0)
        throw new InvalidOperationException("OutputNodeNames is empty");
    if (client.OutputImagesDir is null)
        throw new InvalidOperationException("OutputImagesDir is null");

    // Sends the interrupt request with a 5 second timeout and releases the
    // timeout source (and its timer) once the request has completed.
    async Task InterruptPromptWithTimeoutAsync()
    {
        using var timeoutCts = new CancellationTokenSource(5000);
        await client.InterruptPromptAsync(timeoutCts.Token);
    }

    // Connect preview image handler
    client.PreviewImageReceived += OnPreviewImageReceived;

    ComfyTask? promptTask = null;
    try
    {
        // Register to interrupt if user cancels.
        // The registration is disposed when this run ends, so callbacks do not
        // accumulate when the same token is used for several runs, and a later
        // cancellation cannot send a stale interrupt for an already finished prompt.
        await using var interruptRegistration = cancellationToken.Register(() =>
        {
            Logger.Info("Cancelling prompt");
            InterruptPromptWithTimeoutAsync().SafeFireAndForget();
        });

        try
        {
            promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
        }
        catch (ApiException e)
        {
            Logger.Warn(e, "Api exception while queuing prompt");
            await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
            return;
        }

        // Register progress handler
        promptTask.ProgressUpdate += OnProgressUpdateReceived;

        // Wait for prompt to finish
        await promptTask.Task.WaitAsync(cancellationToken);
        Logger.Trace($"Prompt task {promptTask.Id} finished");

        // Get output images
        var imageOutputs = await client.GetImagesForExecutedPromptAsync(
            promptTask.Id,
            cancellationToken
        );

        ImageGalleryCardViewModel.ImageSources.Clear();

        if (
            !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
        )
        {
            // No images match
            notificationService.Show("No output", "Did not receive any output images");
            return;
        }

        await ProcessOutputImages(images, args);
    }
    finally
    {
        // Disconnect preview image handler
        client.PreviewImageReceived -= OnPreviewImageReceived;

        // Clear progress
        OutputProgress.Value = 0;
        OutputProgress.Text = "";
        ImageGalleryCardViewModel.PreviewImage?.Dispose();
        ImageGalleryCardViewModel.PreviewImage = null;
        ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;

        // Cleanup tasks: detach the progress handler before disposing the task
        if (promptTask is not null)
        {
            promptTask.ProgressUpdate -= OnProgressUpdateReceived;
            promptTask.Dispose();
        }
    }
}
/// <summary>
/// Handles image output metadata for generation runs.
/// Rewrites every output image file with the generation metadata embedded, creates a
/// combined grid image when more than one image was produced, and adds the results
/// (grid first) to the image gallery.
/// </summary>
/// <param name="images">Output images reported for the executed prompt</param>
/// <param name="args">
/// Generation args; Client.OutputImagesDir, Parameters and Project are expected to be
/// non-null here (RunGeneration checks them before calling this method)
/// </param>
private async Task ProcessOutputImages(
    IEnumerable images,
    ImageGenerationEventArgs args
)
{
    // Write metadata to images
    var outputImages = new List<ImageSource>();
    foreach (
        var filePath in images.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
    )
    {
        var bytesWithMetadata = PngDataHelper.AddMetadata(
            await filePath.ReadAllBytesAsync(),
            args.Parameters!,
            args.Project!
        );

        await using (var outputStream = filePath.Info.OpenWrite())
        {
            // OpenWrite does not truncate an existing file; drop the old content first
            // so no stale bytes can remain after the new data.
            outputStream.SetLength(0);
            await outputStream.WriteAsync(bytesWithMetadata);
            await outputStream.FlushAsync();
        }

        outputImages.Add(new ImageSource(filePath));

        EventManager.Instance.OnImageFileAdded(filePath);
    }

    // Load all images to make grid, if multiple
    if (outputImages.Count > 1)
    {
        var loadedImages = outputImages
            .Select(i =>
            {
                // Close the file handle as soon as the image data has been read
                using var stream = i.LocalFile?.Info.OpenRead();
                return SKImage.FromEncodedData(stream);
            })
            .ToImmutableArray();

        byte[] gridBytes;
        try
        {
            using var grid = ImageProcessor.CreateImageGrid(loadedImages);
            using var gridData = grid.Encode();
            gridBytes = gridData.ToArray();
        }
        finally
        {
            // The source images are only needed while the grid is being composed
            foreach (var loadedImage in loadedImages)
            {
                loadedImage?.Dispose();
            }
        }

        var gridBytesWithMetadata = PngDataHelper.AddMetadata(
            gridBytes,
            args.Parameters!,
            args.Project!
        );

        // Save to disk
        var lastName = outputImages.Last().LocalFile?.Info.Name;
        var gridPath = args.Client.OutputImagesDir!.JoinFile($"grid-{lastName}");

        await using (var fileStream = gridPath.Info.OpenWrite())
        {
            // A grid file with the same name may already exist and be longer
            fileStream.SetLength(0);
            await fileStream.WriteAsync(gridBytesWithMetadata);
        }

        // The grid is added before the individual images so it appears first
        var gridImage = new ImageSource(gridPath);
        // Preload
        await gridImage.GetBitmapAsync();
        ImageGalleryCardViewModel.ImageSources.Add(gridImage);

        EventManager.Instance.OnImageFileAdded(gridPath);
    }

    // Add rest of images
    foreach (var img in outputImages)
    {
        // Preload
        await img.GetBitmapAsync();

        ImageGalleryCardViewModel.ImageSources.Add(img);
    }
}
/// <summary>
/// Implementation hook for Generate Image.
/// The base implementation does no work and returns an already completed task.
/// </summary>
/// <param name="overrides">Overrides derived from the flags passed to the command</param>
/// <param name="cancellationToken">Cancellation token</param>
protected virtual Task GenerateImageImpl(
    GenerateOverrides overrides,
    CancellationToken cancellationToken
) => Task.CompletedTask;
/// <summary>
/// Command for the Generate Image button.
/// Converts the flags to overrides and runs <see cref="GenerateImageImpl"/>;
/// a cancellation is logged and not propagated.
/// </summary>
/// <param name="options">Optional overrides (side buttons)</param>
/// <param name="cancellationToken">Cancellation token</param>
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
    GenerateFlags options = default,
    CancellationToken cancellationToken = default
)
{
    try
    {
        await GenerateImageImpl(GenerateOverrides.FromFlags(options), cancellationToken);
    }
    catch (OperationCanceledException)
    {
        // Cancellation is an expected outcome, only worth a debug entry
        Logger.Debug("Image Generation Canceled");
    }
}
/// <summary>
/// Handles the preview image received event from the websocket.
/// Forwards the received image bytes to the image gallery as its preview image.
/// </summary>
/// <param name="sender">Event source (unused)</param>
/// <param name="args">Event data containing the preview image bytes</param>
protected virtual void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args) =>
    ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
/// <summary>
/// Handles the progress update received event from the websocket.
/// Posts an update of the progress view model to the UI thread.
/// </summary>
/// <param name="sender">Event source (unused)</param>
/// <param name="args">Progress value, maximum and (optionally) the running node</param>
protected virtual void OnProgressUpdateReceived(
    object? sender,
    ComfyProgressUpdateEventArgs args
)
{
    Dispatcher.UIThread.Post(ApplyProgress);
    return;

    // Runs on the UI thread
    void ApplyProgress()
    {
        OutputProgress.Value = args.Value;
        OutputProgress.Maximum = args.Maximum;
        OutputProgress.IsIndeterminate = false;

        // The running node is appended only when one is reported
        var nodeSuffix = args.RunningNode != null ? $" {args.RunningNode}" : "";
        OutputProgress.Text = $"({args.Value} / {args.Maximum}){nodeSuffix}";
    }
}
// Arguments for a single RunGeneration call.
public class ImageGenerationEventArgs : EventArgs
{
// Client used to queue the prompt and to fetch the output images
public required ComfyClient Client { get; init; }
// Node graph passed to the client when the prompt is queued
public required NodeDictionary Nodes { get; init; }
// Output node names; RunGeneration requires at least one and reads images for the first entry
public required IReadOnlyList OutputNodeNames { get; init; }
// Passed to PngDataHelper.AddMetadata for each output image; RunGeneration throws if null
public GenerationParameters? Parameters { get; set; }
// Passed to PngDataHelper.AddMetadata for each output image; RunGeneration throws if null
public InferenceProjectDocument? Project { get; set; }
}
// Arguments passed to BuildPrompt.
public class BuildPromptEventArgs : EventArgs
{
// Node builder; a new instance is created for every args object
public ComfyNodeBuilder Builder { get; } = new();
// Generation overrides; defaults to an instance with no overrides set
public GenerateOverrides Overrides { get; set; } = new();
}
/// <summary>
/// Option flags accepted by the Generate Image command.
/// Each member occupies its own bit so values can be combined.
/// </summary>
[Flags]
public enum GenerateFlags
{
    None = 0,
    HiresFixEnable = 0b0_0010,
    HiresFixDisable = 0b0_0100,
    UseCurrentSeed = 0b0_1000,
    UseRandomSeed = 0b1_0000
}
/// <summary>
/// Optional per-run overrides; a null value means "no override".
/// </summary>
public class GenerateOverrides
{
    /// <summary>True or false to override the hires fix setting, null to leave it alone</summary>
    public bool? IsHiresFixEnabled { get; set; }

    /// <summary>True to use the current seed, false to use a random seed, null to leave it alone</summary>
    public bool? UseCurrentSeed { get; set; }

    /// <summary>
    /// Creates overrides from the given flags.
    /// For each pair of flags the "on" flag wins when both are set;
    /// when neither is set the override stays null.
    /// </summary>
    public static GenerateOverrides FromFlags(GenerateFlags flags)
    {
        return new GenerateOverrides
        {
            IsHiresFixEnabled = Resolve(
                GenerateFlags.HiresFixEnable,
                GenerateFlags.HiresFixDisable
            ),
            UseCurrentSeed = Resolve(GenerateFlags.UseCurrentSeed, GenerateFlags.UseRandomSeed)
        };

        // Maps an on/off flag pair to true / false / null
        bool? Resolve(GenerateFlags onFlag, GenerateFlags offFlag)
        {
            if (flags.HasFlag(onFlag))
            {
                return true;
            }

            if (flags.HasFlag(offFlag))
            {
                return false;
            }

            return null;
        }
    }
}
}