using System; using System.Collections.Generic; using System.Collections.Immutable; using System.ComponentModel.DataAnnotations; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Controls.Notifications; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using NLog; using Refit; using SkiaSharp; using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Inference; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Base; /// /// Abstract base class for tab view models that generate images using ClientManager. /// This includes a progress reporter, image output view model, and generation virtual methods. /// [SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")] public abstract partial class InferenceGenerationViewModelBase : InferenceTabViewModelBase, IImageGalleryComponent { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private readonly ISettingsManager settingsManager; private readonly INotificationService notificationService; private readonly ServiceManager vmFactory; [JsonPropertyName("ImageGallery")] public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; } [JsonIgnore] public ImageFolderCardViewModel ImageFolderCardViewModel { get; } [JsonIgnore] public ProgressViewModel OutputProgress { get; } = new(); [JsonIgnore] public IInferenceClientManager ClientManager { get; } /// protected InferenceGenerationViewModelBase( ServiceManager vmFactory, IInferenceClientManager inferenceClientManager, INotificationService notificationService, ISettingsManager settingsManager ) : base(notificationService) { this.notificationService = notificationService; this.settingsManager = settingsManager; this.vmFactory = vmFactory; ClientManager = inferenceClientManager; ImageGalleryCardViewModel = vmFactory.Get(); ImageFolderCardViewModel = vmFactory.Get(); GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); } /// /// Write an image to the default output folder /// protected Task WriteOutputImageAsync( Stream imageStream, ImageGenerationEventArgs args, int batchNum = 0, int batchTotal = 0, bool isGrid = false ) { var defaultOutputDir = settingsManager.ImagesInferenceDirectory; defaultOutputDir.Create(); return WriteOutputImageAsync(imageStream, defaultOutputDir, args, batchNum, batchTotal, isGrid); } /// /// Write an image to an output folder /// protected async Task WriteOutputImageAsync( Stream imageStream, DirectoryPath outputDir, ImageGenerationEventArgs args, int batchNum = 0, int batchTotal = 0, bool isGrid = false ) { var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat; var formatProvider = new FileNameFormatProvider { GenerationParameters = args.Parameters, ProjectType = args.Project?.ProjectType, ProjectName = ProjectFile?.NameWithoutExtension }; // Parse to format if ( string.IsNullOrEmpty(formatTemplateStr) || !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format) ) { // Fallback to default Logger.Warn("Failed to parse format template: {FormatTemplate}, using default", formatTemplateStr); format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider); } if (isGrid) { format = format.WithGridPrefix(); } if (batchNum >= 1 && batchTotal > 1) { format = format.WithBatchPostFix(batchNum, batchTotal); } var fileName = format.GetFileName(); var file = outputDir.JoinFile($"{fileName}.png"); // Until the file is free, keep adding _{i} to the end for (var i = 0; i < 100; i++) { if (!file.Exists) break; file = outputDir.JoinFile($"{fileName}_{i + 1}.png"); } // If that fails, append an 7-char uuid if (file.Exists) { var uuid = Guid.NewGuid().ToString("N")[..7]; file = outputDir.JoinFile($"{fileName}_{uuid}.png"); } await using var fileStream = file.Info.OpenWrite(); await imageStream.CopyToAsync(fileStream); return file; } /// /// Builds the image generation prompt /// protected virtual void BuildPrompt(BuildPromptEventArgs args) { } /// /// Gets ImageSources that need to be uploaded as inputs /// protected virtual IEnumerable GetInputImages() { return Enumerable.Empty(); } protected async Task UploadInputImages(ComfyClient client) { foreach (var image in GetInputImages()) { if (image.LocalFile is { } localFile) { var uploadName = await image.GetHashGuidFileNameAsync(); Logger.Debug("Uploading image {FileName} as {UploadName}", localFile.Name, uploadName); // For pngs, strip metadata since Pillow can't handle some valid files? if (localFile.Info.Extension.Equals(".png", StringComparison.OrdinalIgnoreCase)) { var bytes = PngDataHelper.RemoveMetadata(await localFile.ReadAllBytesAsync()); using var stream = new MemoryStream(bytes); await client.UploadImageAsync(stream, uploadName); } else { await using var stream = localFile.Info.OpenRead(); await client.UploadImageAsync(stream, uploadName); } } } } /// /// Runs a generation task /// /// Thrown if args.Parameters or args.Project are null protected async Task RunGeneration(ImageGenerationEventArgs args, CancellationToken cancellationToken) { var client = args.Client; var nodes = args.Nodes; // Checks if (args.Parameters is null) throw new InvalidOperationException("Parameters is null"); if (args.Project is null) throw new InvalidOperationException("Project is null"); if (args.OutputNodeNames.Count == 0) throw new InvalidOperationException("OutputNodeNames is empty"); if (client.OutputImagesDir is null) throw new InvalidOperationException("OutputImagesDir is null"); // Upload input images await UploadInputImages(client); // Connect preview image handler client.PreviewImageReceived += OnPreviewImageReceived; // Register to interrupt if user cancels var promptInterrupt = cancellationToken.Register(() => { Logger.Info("Cancelling prompt"); client .InterruptPromptAsync(new CancellationTokenSource(5000).Token) .SafeFireAndForget(ex => { Logger.Warn(ex, "Error while interrupting prompt"); }); }); ComfyTask? promptTask = null; try { var timer = Stopwatch.StartNew(); try { promptTask = await client.QueuePromptAsync(nodes, cancellationToken); } catch (ApiException e) { Logger.Warn(e, "Api exception while queuing prompt"); await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync(); return; } // Register progress handler promptTask.ProgressUpdate += OnProgressUpdateReceived; // Delay attaching running node change handler to not show indeterminate progress // if progress updates are received before the prompt starts Task.Run( async () => { try { var delayTime = 250 - (int)timer.ElapsedMilliseconds; if (delayTime > 0) { await Task.Delay(delayTime, cancellationToken); } // ReSharper disable once AccessToDisposedClosure AttachRunningNodeChangedHandler(promptTask); } catch (TaskCanceledException) { } }, cancellationToken ) .SafeFireAndForget(); // Wait for prompt to finish try { await promptTask.Task.WaitAsync(cancellationToken); Logger.Debug($"Prompt task {promptTask.Id} finished"); } catch (ComfyNodeException e) { Logger.Warn(e, "Comfy node exception while queuing prompt"); await DialogHelper .CreateJsonDialog(e.JsonData, "Comfy Error", "Node execution encountered an error") .ShowAsync(); return; } // Get output images var imageOutputs = await client.GetImagesForExecutedPromptAsync(promptTask.Id, cancellationToken); if (imageOutputs.Values.All(images => images is null or { Count: 0 })) { // No images match notificationService.Show("No output", "Did not receive any output images", NotificationType.Warning); return; } // Disable cancellation await promptInterrupt.DisposeAsync(); if (args.ClearOutputImages) { ImageGalleryCardViewModel.ImageSources.Clear(); } await ProcessAllOutputImages(imageOutputs, args); } finally { // Disconnect progress handler client.PreviewImageReceived -= OnPreviewImageReceived; // Clear progress OutputProgress.ClearProgress(); ImageGalleryCardViewModel.PreviewImage?.Dispose(); ImageGalleryCardViewModel.PreviewImage = null; ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false; // Cleanup tasks promptTask?.Dispose(); } } private async Task ProcessAllOutputImages( IReadOnlyDictionary?> images, ImageGenerationEventArgs args ) { foreach (var (nodeName, imageList) in images) { if (imageList is null) { Logger.Warn("No images for node {NodeName}", nodeName); continue; } await ProcessOutputImages(imageList, args, nodeName.Replace('_', ' ')); } } /// /// Handles image output metadata for generation runs /// private async Task ProcessOutputImages( IReadOnlyCollection images, ImageGenerationEventArgs args, string? imageLabel = null ) { var client = args.Client; // Write metadata to images var outputImagesBytes = new List(); var outputImages = new List(); foreach (var (i, comfyImage) in images.Enumerate()) { Logger.Debug("Downloading image: {FileName}", comfyImage.FileName); var imageStream = await client.GetImageStreamAsync(comfyImage); using var ms = new MemoryStream(); await imageStream.CopyToAsync(ms); var imageArray = ms.ToArray(); outputImagesBytes.Add(imageArray); var parameters = args.Parameters!; var project = args.Project!; // Lock seed project.TryUpdateModel("Seed", model => model with { IsRandomizeEnabled = false }); // Seed and batch override for batches if (images.Count > 1 && project.ProjectType is InferenceProjectType.TextToImage) { project = (InferenceProjectDocument)project.Clone(); // Set batch size indexes project.TryUpdateModel( "BatchSize", node => { node[nameof(BatchSizeCardViewModel.BatchCount)] = 1; node[nameof(BatchSizeCardViewModel.IsBatchIndexEnabled)] = true; node[nameof(BatchSizeCardViewModel.BatchIndex)] = i + 1; return node; } ); } var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project); // Write using generated name var filePath = await WriteOutputImageAsync(new MemoryStream(bytesWithMetadata), args, i + 1, images.Count); outputImages.Add(new ImageSource(filePath) { Label = imageLabel }); EventManager.Instance.OnImageFileAdded(filePath); } // Download all images to make grid, if multiple if (outputImages.Count > 1) { var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray(); var project = args.Project!; // Lock seed project.TryUpdateModel("Seed", model => model with { IsRandomizeEnabled = false }); var grid = ImageProcessor.CreateImageGrid(loadedImages); var gridBytes = grid.Encode().ToArray(); var gridBytesWithMetadata = PngDataHelper.AddMetadata(gridBytes, args.Parameters!, args.Project!); // Save to disk var gridPath = await WriteOutputImageAsync(new MemoryStream(gridBytesWithMetadata), args, isGrid: true); // Insert to start of images var gridImage = new ImageSource(gridPath) { Label = imageLabel }; // Preload await gridImage.GetBitmapAsync(); ImageGalleryCardViewModel.ImageSources.Add(gridImage); EventManager.Instance.OnImageFileAdded(gridPath); } // Add rest of images foreach (var img in outputImages) { // Preload await img.GetBitmapAsync(); ImageGalleryCardViewModel.ImageSources.Add(img); } } /// /// Implementation for Generate Image /// protected virtual Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken) { return Task.CompletedTask; } /// /// Command for the Generate Image button /// /// Optional overrides (side buttons) /// Cancellation token [RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)] private async Task GenerateImage(GenerateFlags options = default, CancellationToken cancellationToken = default) { var overrides = GenerateOverrides.FromFlags(options); try { await GenerateImageImpl(overrides, cancellationToken); } catch (OperationCanceledException) { Logger.Debug("Image Generation Canceled"); } catch (ValidationException e) { Logger.Debug("Image Generation Validation Error: {Message}", e.Message); notificationService.Show("Validation Error", e.Message, NotificationType.Error); } } /// /// Shows a prompt and return false if client not connected /// protected async Task CheckClientConnectedWithPrompt() { if (ClientManager.IsConnected) return true; var vm = vmFactory.Get(); await vm.CreateDialog().ShowAsync(); return ClientManager.IsConnected; } /// /// Handles the preview image received event from the websocket. /// Updates the preview image in the image gallery. /// protected virtual void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args) { ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes); } /// /// Handles the progress update received event from the websocket. /// Updates the progress view model. /// protected virtual void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args) { Dispatcher .UIThread .Post(() => { OutputProgress.Value = args.Value; OutputProgress.Maximum = args.Maximum; OutputProgress.IsIndeterminate = false; OutputProgress.Text = $"({args.Value} / {args.Maximum})" + (args.RunningNode != null ? $" {args.RunningNode}" : ""); }); } private void AttachRunningNodeChangedHandler(ComfyTask comfyTask) { // Do initial update if (comfyTask.RunningNodesHistory.TryPeek(out var lastNode)) { OnRunningNodeChanged(comfyTask, lastNode); } comfyTask.RunningNodeChanged += OnRunningNodeChanged; } /// /// Handles the node executing updates received event from the websocket. /// protected virtual void OnRunningNodeChanged(object? sender, string? nodeName) { // Ignore if regular progress updates started if (sender is not ComfyTask { HasProgressUpdateStarted: false }) { return; } Dispatcher .UIThread .Post(() => { OutputProgress.IsIndeterminate = true; OutputProgress.Value = 100; OutputProgress.Maximum = 100; OutputProgress.Text = nodeName; }); } public class ImageGenerationEventArgs : EventArgs { public required ComfyClient Client { get; init; } public required NodeDictionary Nodes { get; init; } public required IReadOnlyList OutputNodeNames { get; init; } public GenerationParameters? Parameters { get; init; } public InferenceProjectDocument? Project { get; init; } public bool ClearOutputImages { get; init; } = true; } public class BuildPromptEventArgs : EventArgs { public ComfyNodeBuilder Builder { get; } = new(); public GenerateOverrides Overrides { get; init; } = new(); public long? SeedOverride { get; init; } public static implicit operator ModuleApplyStepEventArgs(BuildPromptEventArgs args) { var overrides = new Dictionary(); if (args.Overrides.IsHiresFixEnabled.HasValue) { overrides[typeof(HiresFixModule)] = args.Overrides.IsHiresFixEnabled.Value; } return new ModuleApplyStepEventArgs { Builder = args.Builder, IsEnabledOverrides = overrides }; } } }