using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Media.Imaging;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;

#pragma warning disable CS0657 // Not a valid attribute location for this declaration

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// View model for the text-to-image inference tab. It builds a ComfyUI node graph from the
/// stacked cards (model, sampler, hires fix, upscale, seed, batch size), queues it on the
/// connected inference client, and shows the resulting images in the gallery card.
/// </summary>
[View(typeof(InferenceTextToImageView))]
public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;
    private readonly ServiceManager<ViewModelBase> vmFactory;

    public IInferenceClientManager ClientManager { get; }

    public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
    public PromptCardViewModel PromptCardViewModel { get; }
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Upscaler card inside the first stack expander (the "Hires Fix" expander).</summary>
    public UpscalerCardViewModel UpscalerCardViewModel =>
        StackCardViewModel.GetCard<StackExpanderViewModel>().GetCard<UpscalerCardViewModel>();

    /// <summary>Sampler card inside the first stack expander (the "Hires Fix" expander).</summary>
    public SamplerCardViewModel HiresSamplerCardViewModel =>
        StackCardViewModel.GetCard<StackExpanderViewModel>().GetCard<SamplerCardViewModel>();

    /// <summary>Whether the first stack expander ("Hires Fix") is enabled.</summary>
    public bool IsHiresFixEnabled => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled;

    /// <summary>Whether the second stack expander ("Upscale") is enabled.</summary>
    public bool IsUpscaleEnabled => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled;

    [JsonIgnore]
    public ProgressViewModel OutputProgress { get; } = new();

    [ObservableProperty]
    [property: JsonIgnore]
    private string? outputImageSource;

    public InferenceTextToImageViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ServiceManager<ViewModelBase> vmFactory
    )
    {
        this.notificationService = notificationService;
        this.vmFactory = vmFactory;
        ClientManager = inferenceClientManager;

        // Get sub view models from service manager
        var seedCard = vmFactory.Get<SeedCardViewModel>();
        seedCard.GenerateNewSeed();

        ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
        PromptCardViewModel = vmFactory.Get<PromptCardViewModel>();
        StackCardViewModel = vmFactory.Get<StackCardViewModel>();

        // The card order matters: the index-based GetCard<StackExpanderViewModel>(n) lookups
        // above rely on "Hires Fix" being the first expander and "Upscale" the second
        StackCardViewModel.AddCards(
            new LoadableViewModelBase[]
            {
                // Model Card
                vmFactory.Get<ModelCardViewModel>(),
                // Sampler
                vmFactory.Get<SamplerCardViewModel>(),
                // Hires Fix
                vmFactory.Get<StackExpanderViewModel>(stackExpander =>
                {
                    stackExpander.Title = "Hires Fix";
                    stackExpander.AddCards(
                        new LoadableViewModelBase[]
                        {
                            // Hires Fix Upscaler
                            vmFactory.Get<UpscalerCardViewModel>(),
                            // Hires Fix Sampler
                            vmFactory.Get<SamplerCardViewModel>(samplerCard =>
                            {
                                samplerCard.IsDimensionsEnabled = false;
                                samplerCard.IsCfgScaleEnabled = false;
                                samplerCard.IsSamplerSelectionEnabled = false;
                                samplerCard.IsDenoiseStrengthEnabled = true;
                            })
                        }
                    );
                }),
                vmFactory.Get<StackExpanderViewModel>(stackExpander =>
                {
                    stackExpander.Title = "Upscale";
                    stackExpander.AddCards(
                        new LoadableViewModelBase[]
                        {
                            // Post processing upscaler
                            vmFactory.Get<UpscalerCardViewModel>(),
                        }
                    );
                }),
                // Seed
                seedCard,
                // Batch Size
                vmFactory.Get<BatchSizeCardViewModel>(),
            }
        );

        GenerateImageCommand.WithNotificationErrorHandler(notificationService);
    }

    /// <summary>
    /// Builds the ComfyUI prompt graph from the current card settings.
    /// </summary>
    /// <returns>
    /// The node dictionary to queue, and the names of the output nodes whose images should be
    /// collected (only the SaveImage node).
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// No sampler is selected, or post-processing upscale is enabled without a selected upscaler.
    /// </exception>
    private (NodeDictionary prompt, string[] outputs) BuildPrompt()
    {
        using var _ = new CodeTimer();

        var samplerCard = StackCardViewModel.GetCard<SamplerCardViewModel>();
        var batchCard = StackCardViewModel.GetCard<BatchSizeCardViewModel>();
        var modelCard = StackCardViewModel.GetCard<ModelCardViewModel>();
        var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();

        var prompt = new NodeDictionary();
        var builder = new ComfyNodeBuilder(prompt);

        // Checkpoint loader outputs as used below: 0 = model, 1 = clip, 2 = vae
        var checkpointLoader = prompt.AddNamedNode(
            new NamedComfyNode("CheckpointLoader")
            {
                ClassType = "CheckpointLoaderSimple",
                Inputs = new Dictionary<string, object?> { ["ckpt_name"] = modelCard.SelectedModelName }
            }
        );

        var checkpointVae = checkpointLoader.GetOutput(2);

        var emptyLatentImage = prompt.AddNamedNode(
            new NamedComfyNode("EmptyLatentImage")
            {
                ClassType = "EmptyLatentImage",
                Inputs = new Dictionary<string, object?>
                {
                    ["batch_size"] = batchCard.BatchSize,
                    ["height"] = samplerCard.Height,
                    ["width"] = samplerCard.Width,
                }
            }
        );

        var positiveClip = prompt.AddNamedNode(
            new NamedComfyNode("PositiveCLIP")
            {
                ClassType = "CLIPTextEncode",
                Inputs = new Dictionary<string, object?>
                {
                    ["clip"] = checkpointLoader.GetOutput(1),
                    ["text"] = PromptCardViewModel.PromptDocument.Text,
                }
            }
        );

        var negativeClip = prompt.AddNamedNode(
            new NamedComfyNode("NegativeCLIP")
            {
                ClassType = "CLIPTextEncode",
                Inputs = new Dictionary<string, object?>
                {
                    ["clip"] = checkpointLoader.GetOutput(1),
                    ["text"] = PromptCardViewModel.NegativePromptDocument.Text,
                }
            }
        );

        var sampler = prompt.AddNamedNode(
            ComfyNodeBuilder.KSampler(
                "Sampler",
                checkpointLoader.GetOutput(0),
                Convert.ToUInt64(seedCard.Seed),
                samplerCard.Steps,
                samplerCard.CfgScale,
                samplerCard.SelectedSampler?.Name
                    ?? throw new InvalidOperationException("Sampler not selected"),
                "normal",
                positiveClip.GetOutput(0),
                negativeClip.GetOutput(0),
                emptyLatentImage.GetOutput(0),
                samplerCard.DenoiseStrength
            )
        );

        // Track the most recent latent and its pixel size; later stages chain off these
        var lastLatent = sampler.Output;
        var lastLatentWidth = samplerCard.Width;
        var lastLatentHeight = samplerCard.Height;

        var vaeDecoder = prompt.AddNamedNode(
            new NamedComfyNode("VAEDecoder")
            {
                ClassType = "VAEDecode",
                Inputs = new Dictionary<string, object?>
                {
                    ["samples"] = lastLatent,
                    ["vae"] = checkpointVae
                }
            }
        );

        var saveImage = prompt.AddNamedNode(
            new NamedComfyNode("SaveImage")
            {
                ClassType = "SaveImage",
                Inputs = new Dictionary<string, object?>
                {
                    ["filename_prefix"] = "SM-Inference",
                    ["images"] = vaeDecoder.GetOutput(0)
                }
            }
        );

        // If hi-res fix is enabled, add the upscale group and another KSampler node
        if (IsHiresFixEnabled)
        {
            var hiresUpscalerCard = UpscalerCardViewModel;
            var hiresSamplerCard = HiresSamplerCardViewModel;

            LatentNodeConnection hiresLatent;

            if (
                hiresUpscalerCard.SelectedUpscaler is { } selectedUpscaler
                && selectedUpscaler.Type != ComfyUpscalerType.None
            )
            {
                // Requested upscale to this size
                var hiresWidth = (int)Math.Floor(lastLatentWidth * hiresUpscalerCard.Scale);
                var hiresHeight = (int)Math.Floor(lastLatentHeight * hiresUpscalerCard.Scale);

                hiresLatent = builder
                    .Group_UpscaleToLatent(
                        "HiresFix",
                        lastLatent,
                        checkpointVae,
                        selectedUpscaler,
                        hiresWidth,
                        hiresHeight
                    )
                    .Output;

                // Only an actual upscale changes the latent's size
                lastLatentWidth = hiresWidth;
                lastLatentHeight = hiresHeight;
            }
            else
            {
                // No upscaler selected (or "None"): reroute the latent unchanged, keeping its size
                hiresLatent = lastLatent;
            }

            var hiresSampler = prompt.AddNamedNode(
                ComfyNodeBuilder.KSampler(
                    "HiresSampler",
                    checkpointLoader.GetOutput(0),
                    Convert.ToUInt64(seedCard.Seed),
                    hiresSamplerCard.Steps,
                    hiresSamplerCard.CfgScale,
                    // Use hires sampler name if not null, otherwise use the normal sampler name
                    hiresSamplerCard.SelectedSampler?.Name
                        ?? samplerCard.SelectedSampler?.Name
                        ?? throw new InvalidOperationException("Sampler not selected"),
                    "normal",
                    positiveClip.GetOutput(0),
                    negativeClip.GetOutput(0),
                    hiresLatent,
                    hiresSamplerCard.DenoiseStrength
                )
            );

            // Set as last latent
            lastLatent = hiresSampler.Output;

            // Reroute the VAEDecoder's input to be from the hires sampler
            vaeDecoder.Inputs["samples"] = lastLatent;
        }

        // If upscale is enabled, add another upscale group that outputs an image directly
        if (IsUpscaleEnabled)
        {
            var postUpscalerCard = StackCardViewModel
                .GetCard<StackExpanderViewModel>(1)
                .GetCard<UpscalerCardViewModel>();

            var postUpscaler =
                postUpscalerCard.SelectedUpscaler
                ?? throw new InvalidOperationException("Upscaler not selected");

            var upscaleWidth = (int)Math.Floor(lastLatentWidth * postUpscalerCard.Scale);
            var upscaleHeight = (int)Math.Floor(lastLatentHeight * postUpscalerCard.Scale);

            // Build group
            var postUpscaleGroup = builder.Group_UpscaleToImage(
                "PostUpscale",
                lastLatent,
                checkpointVae,
                postUpscaler,
                upscaleWidth,
                upscaleHeight
            );

            // The group decodes on its own, so the original vae decoder is no longer needed
            prompt.Remove(vaeDecoder.Name);

            // Set as the input for save image
            saveImage.Inputs["images"] = postUpscaleGroup.Output;
        }

        prompt.NormalizeConnectionTypes();

        return (prompt, new[] { saveImage.Name });
    }

    /// <summary>Mirrors the prompt task's progress into <see cref="OutputProgress"/>.</summary>
    private void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
    {
        OutputProgress.Value = args.Value;
        OutputProgress.Maximum = args.Maximum;
        OutputProgress.IsIndeterminate = false;

        OutputProgress.Text =
            $"({args.Value} / {args.Maximum})"
            + (args.RunningNode != null ? $" {args.RunningNode}" : "");
    }

    /// <summary>
    /// Decodes a preview image sent by the client and shows it in the gallery's preview overlay,
    /// disposing the previously shown preview bitmap.
    /// </summary>
    private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
    {
        // Decode to bitmap
        using var stream = new MemoryStream(args.ImageBytes);
        var bitmap = new Bitmap(stream);

        ImageGalleryCardViewModel.PreviewImage?.Dispose();
        ImageGalleryCardViewModel.PreviewImage = bitmap;
        ImageGalleryCardViewModel.IsPreviewOverlayEnabled = true;
    }

    /// <summary>
    /// Queues the current prompt, waits for it to finish, and loads the output images into the
    /// gallery. When there are multiple images that are available locally, a grid image is also
    /// created and shown first.
    /// </summary>
    /// <param name="cancellationToken">
    /// Cancels the wait; on cancellation an interrupt request is also sent to the client.
    /// </param>
    private async Task GenerateImageImpl(CancellationToken cancellationToken = default)
    {
        if (!ClientManager.IsConnected)
        {
            notificationService.Show("Client not connected", "Please connect first");
            return;
        }

        // If enabled, randomize the seed
        var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();
        if (seedCard.IsRandomizeEnabled)
        {
            seedCard.GenerateNewSeed();
        }

        var client = ClientManager.Client;

        var (nodes, outputNodeNames) = BuildPrompt();

        client.PreviewImageReceived += OnPreviewImageReceived;

        ComfyTask? promptTask = null;

        try
        {
            // Register to interrupt if user cancels; the registration is disposed when this
            // scope ends so the callback cannot fire after generation has finished
            await using var interruptRegistration = cancellationToken.Register(() =>
            {
                Logger.Info("Cancelling prompt");
                client
                    .InterruptPromptAsync(new CancellationTokenSource(5000).Token)
                    .SafeFireAndForget();
            });

            try
            {
                promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
            }
            catch (ApiException e)
            {
                Logger.Warn(e, "Api exception while queuing prompt");
                await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
                return;
            }

            // Register progress handler
            promptTask.ProgressUpdate += OnProgressUpdateReceived;

            // Wait for prompt to finish
            await promptTask.Task.WaitAsync(cancellationToken);
            Logger.Trace($"Prompt task {promptTask.Id} finished");

            // Get output images
            var imageOutputs = await client.GetImagesForExecutedPromptAsync(
                promptTask.Id,
                cancellationToken
            );

            ImageGalleryCardViewModel.ImageSources.Clear();

            var images = imageOutputs[outputNodeNames[0]];
            if (images is null)
                return;

            List<ImageSource> outputImages;
            // Use local file path if available, otherwise use remote URL
            if (client.OutputImagesDir is { } outputPath)
            {
                outputImages = images
                    .Select(i => new ImageSource(i.ToFilePath(outputPath)))
                    .ToList();
            }
            else
            {
                outputImages = images
                    .Select(i => new ImageSource(i.ToUri(client.BaseAddress)))
                    .ToList();
            }

            // Building a grid needs the image files on disk (to read them and to save the grid
            // next to them), so only do it for multiple images with a local output directory
            if (outputImages.Count > 1 && client.OutputImagesDir is { } gridOutputDir)
            {
                var gridBytes = EncodeImageGrid(outputImages);

                var lastName = outputImages[^1].LocalFile?.Info.Name;
                var gridPath = gridOutputDir.JoinFile($"grid-{lastName}");

                // Create() truncates an existing file (OpenWrite() would leave stale trailing bytes)
                await using (var fileStream = gridPath.Info.Create())
                {
                    await fileStream.WriteAsync(gridBytes, cancellationToken);
                }

                // ImageSources was cleared above, so the grid becomes the first entry
                ImageGalleryCardViewModel.ImageSources.Add(new ImageSource(gridPath));
            }

            // Add rest of images
            ImageGalleryCardViewModel.ImageSources.AddRange(outputImages);

            // Preload all images
            await Task.WhenAll(outputImages.Select(i => i.BitmapAsync));
        }
        finally
        {
            // Stop receiving previews first, so a late preview cannot re-enable the overlay
            // after it is cleared below
            client.PreviewImageReceived -= OnPreviewImageReceived;

            OutputProgress.Value = 0;
            OutputProgress.Text = "";

            ImageGalleryCardViewModel.PreviewImage?.Dispose();
            ImageGalleryCardViewModel.PreviewImage = null;
            ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;

            promptTask?.Dispose();
        }
    }

    /// <summary>
    /// Decodes the given locally stored images and combines them into a single grid image,
    /// returned in encoded form. All intermediate Skia objects and file streams are disposed.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// An image has no local file, or its file could not be decoded.
    /// </exception>
    private static byte[] EncodeImageGrid(IEnumerable<ImageSource> images)
    {
        var decodedImages = new List<SKImage>();
        try
        {
            foreach (var image in images)
            {
                var file =
                    image.LocalFile
                    ?? throw new InvalidOperationException(
                        "Cannot create an image grid from a remote image"
                    );

                // SKImage.FromEncodedData(Stream) copies the data, so the file can be closed right away
                using var stream = file.Info.OpenRead();
                var decoded =
                    SKImage.FromEncodedData(stream)
                    ?? throw new InvalidOperationException($"Failed to decode image {file.Info.Name}");
                decodedImages.Add(decoded);
            }

            using var grid = ImageProcessor.CreateImageGrid(decodedImages.ToImmutableArray());
            using var encoded = grid.Encode();
            return encoded.ToArray();
        }
        finally
        {
            foreach (var decodedImage in decodedImages)
            {
                decodedImage.Dispose();
            }
        }
    }

    /// <summary>
    /// Command entry point for image generation; a cancellation is logged rather than surfaced
    /// as an error.
    /// </summary>
    [RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
    private async Task GenerateImage(CancellationToken cancellationToken = default)
    {
        try
        {
            await GenerateImageImpl(cancellationToken);
        }
        catch (OperationCanceledException e)
        {
            Logger.Debug($"[Image Generation Canceled] {e.Message}");
        }
    }
}