using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using NLog; using Refit; using SkiaSharp; using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Inference; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; namespace StabilityMatrix.Avalonia.ViewModels.Base; /// /// Abstract base class for tab view models that generate images using ClientManager. /// This includes a progress reporter, image output view model, and generation virtual methods. 
/// <summary>
/// Abstract base class for tab view models that generate images using ClientManager.
/// This includes a progress reporter, image output view model, and generation virtual methods.
/// </summary>
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract partial class InferenceGenerationViewModelBase
    : InferenceTabViewModelBase,
        IImageGalleryComponent
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Upper bound on how long the interrupt request (sent when the user cancels) may take.
    /// </summary>
    private static readonly TimeSpan InterruptTimeout = TimeSpan.FromSeconds(5);

    private readonly INotificationService notificationService;

    [JsonPropertyName("ImageGallery")]
    public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }

    [JsonIgnore]
    public ImageFolderCardViewModel ImageFolderCardViewModel { get; }

    [JsonIgnore]
    public ProgressViewModel OutputProgress { get; } = new();

    [JsonIgnore]
    public IInferenceClientManager ClientManager { get; }

    /// <inheritdoc />
    protected InferenceGenerationViewModelBase(
        ServiceManager<ViewModelBase> vmFactory,
        IInferenceClientManager inferenceClientManager,
        INotificationService notificationService
    )
    {
        this.notificationService = notificationService;

        ClientManager = inferenceClientManager;

        ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
        ImageFolderCardViewModel = vmFactory.Get<ImageFolderCardViewModel>();

        GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
    }

    /// <summary>
    /// Builds the image generation prompt. The base implementation does nothing.
    /// </summary>
    protected virtual void BuildPrompt(BuildPromptEventArgs args) { }

    /// <summary>
    /// Runs a generation task: queues <see cref="ImageGenerationEventArgs.Nodes"/> on the client,
    /// reports progress and preview images while it executes, then writes metadata into the
    /// output images of the first output node and adds them to the image gallery.
    /// </summary>
    /// <param name="args">Client, node graph, output node names and metadata sources for this run.</param>
    /// <param name="cancellationToken">
    /// Cancels waiting for the run; when triggered, an interrupt request is also sent to the client.
    /// </param>
    /// <exception cref="InvalidOperationException">
    /// Thrown if args.Parameters or args.Project are null, args.OutputNodeNames is empty,
    /// or the client's OutputImagesDir is null.
    /// </exception>
    protected async Task RunGeneration(
        ImageGenerationEventArgs args,
        CancellationToken cancellationToken
    )
    {
        var client = args.Client;
        var nodes = args.Nodes;

        // Checks
        if (args.Parameters is null)
            throw new InvalidOperationException("Parameters is null");
        if (args.Project is null)
            throw new InvalidOperationException("Project is null");
        if (args.OutputNodeNames.Count == 0)
            throw new InvalidOperationException("OutputNodeNames is empty");
        if (client.OutputImagesDir is null)
            throw new InvalidOperationException("OutputImagesDir is null");

        // Connect preview image handler
        client.PreviewImageReceived += OnPreviewImageReceived;

        ComfyTask? promptTask = null;

        try
        {
            // Register to interrupt if user cancels. The registration is disposed when this
            // scope exits, so cancelling the token after the run has finished no longer
            // fires an interrupt at the client.
            using var promptInterrupt = cancellationToken.Register(() =>
            {
                Logger.Info("Cancelling prompt");
                InterruptPromptWithTimeoutAsync(client)
                    .SafeFireAndForget(ex => Logger.Warn(ex, "Error while interrupting prompt"));
            });

            try
            {
                promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
            }
            catch (ApiException e)
            {
                Logger.Warn(e, "Api exception while queuing prompt");
                await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
                return;
            }

            // Register progress handler
            promptTask.ProgressUpdate += OnProgressUpdateReceived;

            // Wait for prompt to finish
            await promptTask.Task.WaitAsync(cancellationToken);
            Logger.Trace($"Prompt task {promptTask.Id} finished");

            // Get output images
            var imageOutputs = await client.GetImagesForExecutedPromptAsync(
                promptTask.Id,
                cancellationToken
            );

            ImageGalleryCardViewModel.ImageSources.Clear();

            if (
                !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
                || images is null
            )
            {
                // No images match
                notificationService.Show("No output", "Did not receive any output images");
                return;
            }

            await ProcessOutputImages(images, args);
        }
        finally
        {
            // Disconnect preview handler
            client.PreviewImageReceived -= OnPreviewImageReceived;

            // Clear progress
            OutputProgress.Value = 0;
            OutputProgress.Text = "";
            ImageGalleryCardViewModel.PreviewImage?.Dispose();
            ImageGalleryCardViewModel.PreviewImage = null;
            ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;

            // Cleanup tasks (detach the progress handler first; removing an
            // unsubscribed handler is a no-op)
            if (promptTask is not null)
            {
                promptTask.ProgressUpdate -= OnProgressUpdateReceived;
                promptTask.Dispose();
            }
        }
    }

    /// <summary>
    /// Sends an interrupt request to <paramref name="client"/>, giving up after
    /// <see cref="InterruptTimeout"/>. The timeout source is disposed once the request completes.
    /// </summary>
    private static async Task InterruptPromptWithTimeoutAsync(ComfyClient client)
    {
        using var timeoutCts = new CancellationTokenSource(InterruptTimeout);
        await client.InterruptPromptAsync(timeoutCts.Token);
    }

    /// <summary>
    /// Handles image output metadata for generation runs: embeds the generation parameters and
    /// project into each output file, builds a grid image when there is more than one output,
    /// and adds everything (grid first) to the image gallery.
    /// </summary>
    /// <param name="images">Images reported by the client for the output node.</param>
    /// <param name="args">
    /// Generation args; Parameters, Project and Client.OutputImagesDir are expected to be
    /// non-null (checked by <see cref="RunGeneration"/>).
    /// </param>
    private async Task ProcessOutputImages(
        IEnumerable<ComfyImage> images,
        ImageGenerationEventArgs args
    )
    {
        // Write metadata to images
        var outputImages = new List<ImageSource>();
        foreach (
            var filePath in images.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
        )
        {
            var bytesWithMetadata = PngDataHelper.AddMetadata(
                await filePath.ReadAllBytesAsync(),
                args.Parameters!,
                args.Project!
            );

            // FileInfo.Create truncates the existing file; OpenWrite would leave stale
            // trailing bytes behind if the new content were ever shorter than the old.
            await using (var outputStream = filePath.Info.Create())
            {
                await outputStream.WriteAsync(bytesWithMetadata);
                await outputStream.FlushAsync();
            }

            outputImages.Add(new ImageSource(filePath));

            EventManager.Instance.OnImageFileAdded(filePath);
        }

        // Load all images to make grid, if multiple
        if (outputImages.Count > 1)
        {
            var gridBytes = CreateGridImageBytes(outputImages);
            var gridBytesWithMetadata = PngDataHelper.AddMetadata(
                gridBytes,
                args.Parameters!,
                args.Project!
            );

            // Save to disk
            var lastName = outputImages.Last().LocalFile?.Info.Name;
            var gridPath = args.Client.OutputImagesDir!.JoinFile($"grid-{lastName}");

            await using (var fileStream = gridPath.Info.Create())
            {
                await fileStream.WriteAsync(gridBytesWithMetadata);
            }

            // RunGeneration clears ImageSources before calling this method, so adding the grid
            // now places it at the start of the gallery.
            var gridImage = new ImageSource(gridPath);

            // Preload
            await gridImage.GetBitmapAsync();
            ImageGalleryCardViewModel.ImageSources.Add(gridImage);

            EventManager.Instance.OnImageFileAdded(gridPath);
        }

        // Add rest of images
        foreach (var img in outputImages)
        {
            // Preload
            await img.GetBitmapAsync();
            ImageGalleryCardViewModel.ImageSources.Add(img);
        }
    }

    /// <summary>
    /// Decodes the local file of each image source, combines them with
    /// <c>ImageProcessor.CreateImageGrid</c>, and returns the encoded grid bytes.
    /// The file streams, decoded images, grid image and encoded data are all disposed
    /// before returning.
    /// </summary>
    private static byte[] CreateGridImageBytes(IReadOnlyCollection<ImageSource> sources)
    {
        var loadedImages = new List<SKImage>(sources.Count);
        try
        {
            foreach (var source in sources)
            {
                using var stream = source.LocalFile?.Info.OpenRead();
                loadedImages.Add(SKImage.FromEncodedData(stream));
            }

            using var grid = ImageProcessor.CreateImageGrid(loadedImages.ToImmutableArray());
            using var gridData = grid.Encode();
            return gridData.ToArray();
        }
        finally
        {
            foreach (var loaded in loadedImages)
            {
                // FromEncodedData may return null when decoding fails
                loaded?.Dispose();
            }
        }
    }

    /// <summary>
    /// Implementation for Generate Image. The base implementation completes immediately.
    /// </summary>
    protected virtual Task GenerateImageImpl(
        GenerateOverrides overrides,
        CancellationToken cancellationToken
    )
    {
        return Task.CompletedTask;
    }

    /// <summary>
    /// Command for the Generate Image button
    /// </summary>
    /// <param name="options">Optional overrides (side buttons)</param>
    /// <param name="cancellationToken">Cancellation token</param>
    [RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
    private async Task GenerateImage(
        GenerateFlags options = default,
        CancellationToken cancellationToken = default
    )
    {
        try
        {
            var overrides = GenerateOverrides.FromFlags(options);

            await GenerateImageImpl(overrides, cancellationToken);
        }
        catch (OperationCanceledException)
        {
            // Cancellation is an expected outcome of the cancel command, not an error
            Logger.Debug("Image Generation Canceled");
        }
    }

    /// <summary>
    /// Handles the preview image received event from the websocket.
    /// Updates the preview image in the image gallery.
    /// </summary>
    protected virtual void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
    {
        ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
    }

    /// <summary>
    /// Handles the progress update received event from the websocket.
    /// Updates the progress view model on the UI thread.
    /// </summary>
    protected virtual void OnProgressUpdateReceived(
        object? sender,
        ComfyProgressUpdateEventArgs args
    )
    {
        Dispatcher.UIThread.Post(() =>
        {
            OutputProgress.Value = args.Value;
            OutputProgress.Maximum = args.Maximum;
            OutputProgress.IsIndeterminate = false;

            OutputProgress.Text =
                $"({args.Value} / {args.Maximum})"
                + (args.RunningNode != null ? $" {args.RunningNode}" : "");
        });
    }

    /// <summary>
    /// Inputs for <see cref="RunGeneration"/>.
    /// </summary>
    public class ImageGenerationEventArgs : EventArgs
    {
        public required ComfyClient Client { get; init; }
        public required NodeDictionary Nodes { get; init; }

        /// <summary>
        /// Output node names; only the first entry is read by <see cref="RunGeneration"/>.
        /// </summary>
        public required IReadOnlyList<string> OutputNodeNames { get; init; }

        /// <summary>Embedded as metadata into output images; required by RunGeneration.</summary>
        public GenerationParameters? Parameters { get; set; }

        /// <summary>Embedded as metadata into output images; required by RunGeneration.</summary>
        public InferenceProjectDocument? Project { get; set; }
    }

    /// <summary>
    /// Arguments for <see cref="BuildPrompt"/>: a fresh node builder plus the active overrides.
    /// </summary>
    public class BuildPromptEventArgs : EventArgs
    {
        public ComfyNodeBuilder Builder { get; } = new();
        public GenerateOverrides Overrides { get; set; } = new();
    }

    /// <summary>
    /// Flags passed by the Generate Image command's side buttons.
    /// </summary>
    [Flags]
    public enum GenerateFlags
    {
        None = 0,
        HiresFixEnable = 1 << 1,
        HiresFixDisable = 1 << 2,
        UseCurrentSeed = 1 << 3,
        UseRandomSeed = 1 << 4
    }

    /// <summary>
    /// Tri-state overrides derived from <see cref="GenerateFlags"/>; null means "no override".
    /// </summary>
    public class GenerateOverrides
    {
        public bool? IsHiresFixEnabled { get; set; }
        public bool? UseCurrentSeed { get; set; }

        /// <summary>
        /// Converts flags to overrides. When both the enable and disable flag of a pair are set,
        /// the enable flag wins (HiresFixEnable over HiresFixDisable, UseCurrentSeed over UseRandomSeed).
        /// </summary>
        public static GenerateOverrides FromFlags(GenerateFlags flags)
        {
            var overrides = new GenerateOverrides()
            {
                IsHiresFixEnabled = flags.HasFlag(GenerateFlags.HiresFixEnable)
                    ? true
                    : flags.HasFlag(GenerateFlags.HiresFixDisable)
                        ? false
                        : null,
                UseCurrentSeed = flags.HasFlag(GenerateFlags.UseCurrentSeed)
                    ? true
                    : flags.HasFlag(GenerateFlags.UseRandomSeed)
                        ? false
                        : null
            };

            return overrides;
        }
    }
}