using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Media.Imaging;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;

#pragma warning disable CS0657 // Not a valid attribute location for this declaration

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// View model for the text-to-image inference tab.
/// It builds a Comfy node graph from its cards (model, sampler, optional hires fix, seed,
/// batch size) and queues it on the connected inference client.
/// It also shows progress, live preview images and the resulting output images.
/// </summary>
[View(typeof(InferenceTextToImageView))]
public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;
    private readonly ServiceManager vmFactory;

    /// <summary>Provides the connection state and the client used to queue prompts.</summary>
    public IInferenceClientManager ClientManager { get; }

    /// <summary>Gallery card that shows the live preview and the final output images.</summary>
    public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }

    /// <summary>Card holding the positive and negative prompt documents.</summary>
    public PromptCardViewModel PromptCardViewModel { get; }

    /// <summary>Stack of settings cards (model, sampler, hires fix, seed, batch size).</summary>
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Upscaler card nested inside the "Hires Fix" expander of the card stack.</summary>
    public UpscalerCardViewModel UpscalerCardViewModel =>
        StackCardViewModel.GetCard().GetCard();

    /// <summary>Second-pass sampler card nested inside the "Hires Fix" expander.</summary>
    public SamplerCardViewModel HiresSamplerCardViewModel =>
        StackCardViewModel.GetCard().GetCard();

    /// <summary>Whether the "Hires Fix" expander card is enabled.</summary>
    public bool IsHiresFixEnabled => StackCardViewModel.GetCard().IsEnabled;

    /// <summary>Progress of the currently running prompt.</summary>
    [JsonIgnore]
    public ProgressViewModel OutputProgress { get; } = new();

    [ObservableProperty]
    [property: JsonIgnore]
    private string? outputImageSource;

    /// <summary>
    /// Creates the tab and populates its card stack with view models obtained from
    /// <paramref name="vmFactory"/>.
    /// </summary>
    public InferenceTextToImageViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ServiceManager vmFactory
    )
    {
        this.notificationService = notificationService;
        this.vmFactory = vmFactory;
        ClientManager = inferenceClientManager;

        // Get sub view models from service manager

        var seedCard = vmFactory.Get();
        seedCard.GenerateNewSeed();

        ImageGalleryCardViewModel = vmFactory.Get();
        PromptCardViewModel = vmFactory.Get();

        StackCardViewModel = vmFactory.Get();

        StackCardViewModel.AddCards(
            new LoadableViewModelBase[]
            {
                // Model Card
                vmFactory.Get(),
                // Sampler
                vmFactory.Get(),
                // Hires Fix
                vmFactory.Get(stackExpander =>
                {
                    stackExpander.Title = "Hires Fix";
                    stackExpander.AddCards(
                        new LoadableViewModelBase[]
                        {
                            // Hires Fix Upscaler
                            vmFactory.Get(),
                            // Hires Fix Sampler
                            vmFactory.Get(samplerCard =>
                            {
                                // The second pass reuses the first pass's dimensions and CFG scale,
                                // and only exposes steps and denoise strength
                                samplerCard.IsDimensionsEnabled = false;
                                samplerCard.IsCfgScaleEnabled = false;
                                samplerCard.IsSamplerSelectionEnabled = false;
                                samplerCard.IsDenoiseStrengthEnabled = true;
                            })
                        }
                    );
                }),
                // Seed
                seedCard,
                // Batch Size
                vmFactory.Get(),
            }
        );

        GenerateImageCommand.WithNotificationErrorHandler(notificationService);
    }

    /// <summary>
    /// Builds the Comfy node graph for the current card settings.
    /// </summary>
    /// <returns>
    /// The node dictionary to queue.
    /// Also returns the names of the output nodes whose images should be collected;
    /// currently this is only the "SaveImage" node.
    /// </returns>
    /// <exception cref="InvalidOperationException">No sampler is selected.</exception>
    private (NodeDictionary prompt, string[] outputs) BuildPrompt()
    {
        using var _ = new CodeTimer();

        var samplerCard = StackCardViewModel.GetCard();
        var batchCard = StackCardViewModel.GetCard();
        var modelCard = StackCardViewModel.GetCard();
        var seedCard = StackCardViewModel.GetCard();

        var prompt = new NodeDictionary();
        var builder = new ComfyNodeBuilder(prompt);

        // The first-pass and hires samplers share the same seed
        var seed = Convert.ToUInt64(seedCard.Seed);

        // Checkpoint outputs as used below: 0 = model, 1 = clip, 2 = vae
        var checkpointLoader = prompt.AddNamedNode(
            new NamedComfyNode("CheckpointLoader")
            {
                ClassType = "CheckpointLoaderSimple",
                Inputs = new() { ["ckpt_name"] = modelCard.SelectedModelName }
            }
        );

        var checkpointVae = checkpointLoader.GetOutput(2);

        var emptyLatentImage = prompt.AddNamedNode(
            new NamedComfyNode("EmptyLatentImage")
            {
                ClassType = "EmptyLatentImage",
                Inputs = new()
                {
                    ["batch_size"] = batchCard.BatchSize,
                    ["height"] = samplerCard.Height,
                    ["width"] = samplerCard.Width,
                }
            }
        );

        var positiveClip = prompt.AddNamedNode(
            new NamedComfyNode("PositiveCLIP")
            {
                ClassType = "CLIPTextEncode",
                Inputs = new()
                {
                    ["clip"] = checkpointLoader.GetOutput(1),
                    ["text"] = PromptCardViewModel.PromptDocument.Text,
                }
            }
        );

        var negativeClip = prompt.AddNamedNode(
            new NamedComfyNode("NegativeCLIP")
            {
                ClassType = "CLIPTextEncode",
                Inputs = new()
                {
                    ["clip"] = checkpointLoader.GetOutput(1),
                    ["text"] = PromptCardViewModel.NegativePromptDocument.Text,
                }
            }
        );

        var sampler = prompt.AddNamedNode(
            ComfyNodeBuilder.KSampler(
                "Sampler",
                checkpointLoader.GetOutput(0),
                seed,
                samplerCard.Steps,
                samplerCard.CfgScale,
                samplerCard.SelectedSampler?.Name
                    ?? throw new InvalidOperationException("Sampler not selected"),
                "normal",
                positiveClip.GetOutput(0),
                negativeClip.GetOutput(0),
                emptyLatentImage.GetOutput(0),
                samplerCard.DenoiseStrength
            )
        );

        var vaeDecoder = prompt.AddNamedNode(
            new NamedComfyNode("VAEDecoder")
            {
                ClassType = "VAEDecode",
                Inputs = new()
                {
                    ["samples"] = sampler.GetOutput(0),
                    ["vae"] = checkpointLoader.GetOutput(2)
                }
            }
        );

        var saveImage = prompt.AddNamedNode(
            new NamedComfyNode("SaveImage")
            {
                ClassType = "SaveImage",
                Inputs = new()
                {
                    ["filename_prefix"] = "SM-Inference",
                    ["images"] = vaeDecoder.GetOutput(0)
                }
            }
        );

        // If hi-res fix is enabled, upscale the first-pass latent and run a second sampler on it.
        // The VAE decoder is then fed from that sampler instead of the first one.
        if (IsHiresFixEnabled)
        {
            var hiresUpscalerCard = UpscalerCardViewModel;
            var hiresSamplerCard = HiresSamplerCardViewModel;

            // Select between latent upscale and model (ESRGAN) upscale based on the upscaler type
            var selectedUpscaler = hiresUpscalerCard.SelectedUpscaler;

            LatentNodeConnection hiresOutput;

            if (selectedUpscaler?.Type == ComfyUpscalerType.Latent)
            {
                hiresOutput = prompt
                    .AddNamedNode(
                        new NamedComfyNode("LatentUpscale")
                        {
                            ClassType = "LatentUpscale",
                            Inputs = new()
                            {
                                ["upscale_method"] = selectedUpscaler.Value.Name,
                                ["width"] = samplerCard.Width * hiresUpscalerCard.Scale,
                                ["height"] = samplerCard.Height * hiresUpscalerCard.Scale,
                                ["crop"] = "disabled",
                                ["samples"] = sampler.Output
                            }
                        }
                    )
                    .GetOutput(0);
            }
            else if (selectedUpscaler?.Type == ComfyUpscalerType.ESRGAN)
            {
                // Convert to image space
                var samplerImage = builder.Lambda_LatentToImage(sampler.Output, checkpointVae);
                // Do group upscale
                var modelUpscaler = builder.Group_UpscaleWithModel(
                    "Upscaler",
                    selectedUpscaler.Value.Name,
                    samplerImage
                );
                // Convert back to latent space
                hiresOutput = builder.Lambda_ImageToLatent(modelUpscaler.Output, checkpointVae);
            }
            else
            {
                // If no upscaler selected or none, just reroute the latent image
                hiresOutput = sampler.Output;
            }

            var hiresSampler = prompt.AddNamedNode(
                ComfyNodeBuilder.KSampler(
                    "HiresSampler",
                    checkpointLoader.GetOutput(0),
                    seed,
                    hiresSamplerCard.Steps,
                    hiresSamplerCard.CfgScale,
                    // Use hires sampler name if not null, otherwise use the normal sampler name
                    hiresSamplerCard.SelectedSampler?.Name
                        ?? samplerCard.SelectedSampler?.Name
                        ?? throw new InvalidOperationException("Sampler not selected"),
                    "normal",
                    positiveClip.GetOutput(0),
                    negativeClip.GetOutput(0),
                    hiresOutput,
                    hiresSamplerCard.DenoiseStrength
                )
            );

            // Reroute the VAEDecoder's input to be from the hires sampler
            vaeDecoder.Inputs["samples"] = hiresSampler.Output;
        }

        prompt.NormalizeConnectionTypes();

        return (prompt, new[] { saveImage.Name });
    }

    /// <summary>Mirrors a prompt progress update into <see cref="OutputProgress"/>.</summary>
    private void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
    {
        OutputProgress.Value = args.Value;
        OutputProgress.Maximum = args.Maximum;
        OutputProgress.IsIndeterminate = false;

        OutputProgress.Text =
            $"({args.Value} / {args.Maximum})"
            + (args.RunningNode != null ? $" {args.RunningNode}" : "");
    }

    /// <summary>
    /// Decodes a preview image pushed by the client and shows it in the gallery overlay.
    /// The previous preview bitmap is disposed.
    /// </summary>
    private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
    {
        // Decode to bitmap
        using var stream = new MemoryStream(args.ImageBytes);
        var bitmap = new Bitmap(stream);

        // Replace the property value before disposing the old bitmap,
        // so PreviewImage never refers to a disposed bitmap
        var previousPreview = ImageGalleryCardViewModel.PreviewImage;
        ImageGalleryCardViewModel.PreviewImage = bitmap;
        ImageGalleryCardViewModel.IsPreviewOverlayEnabled = true;
        previousPreview?.Dispose();
    }

    /// <summary>
    /// Builds the prompt, queues it on the connected client and waits for it to finish.
    /// The resulting images are then loaded into the gallery. When there are several local
    /// outputs, a grid image is saved and shown first.
    /// </summary>
    /// <param name="cancellationToken">
    /// Stops waiting for the prompt. When cancelled, an interrupt request is also sent to the
    /// client.
    /// </param>
    private async Task GenerateImageImpl(CancellationToken cancellationToken = default)
    {
        if (!ClientManager.IsConnected)
        {
            notificationService.Show("Client not connected", "Please connect first");
            return;
        }

        // If enabled, randomize the seed
        var seedCard = StackCardViewModel.GetCard();
        if (seedCard.IsRandomizeEnabled)
        {
            seedCard.GenerateNewSeed();
        }

        var client = ClientManager.Client;

        var (nodes, outputNodeNames) = BuildPrompt();

        // Live preview images are delivered through this client event while the prompt runs
        client.PreviewImageReceived += OnPreviewImageReceived;

        ComfyTask? promptTask = null;

        try
        {
            // Register to interrupt if the user cancels. The token source belongs to the caller
            // and may be cancelled after this run has completed. Disposing the registration at
            // the end of this scope stops a stale callback from interrupting an unrelated prompt.
            await using var interruptRegistration = cancellationToken.Register(() =>
            {
                Logger.Info("Cancelling prompt");
                InterruptWithTimeoutAsync().SafeFireAndForget();
            });

            try
            {
                promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
            }
            catch (ApiException e)
            {
                Logger.Warn(e, "Api exception while queuing prompt");
                await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
                return;
            }

            // Register progress handler
            promptTask.ProgressUpdate += OnProgressUpdateReceived;

            // Wait for prompt to finish
            await promptTask.Task.WaitAsync(cancellationToken);
            Logger.Trace($"Prompt task {promptTask.Id} finished");

            // Get output images
            var imageOutputs = await client.GetImagesForExecutedPromptAsync(
                promptTask.Id,
                cancellationToken
            );

            ImageGalleryCardViewModel.ImageSources.Clear();

            var images = imageOutputs[outputNodeNames[0]];
            if (images is null)
                return;

            // Use local file path if available, otherwise use remote URL
            var outputImages = client.OutputImagesDir is { } outputPath
                ? images.Select(i => new ImageSource(i.ToFilePath(outputPath))).ToList()
                : images.Select(i => new ImageSource(i.ToUri(client.BaseAddress))).ToList();

            // Combine multiple outputs into a grid. This needs every image to be a local file,
            // so the grid is skipped when outputs are only reachable by URL.
            if (
                outputImages.Count > 1
                && client.OutputImagesDir is { } gridOutputDir
                && outputImages.All(i => i.LocalFile is not null)
            )
            {
                var loadedImages = outputImages
                    .Select(i =>
                    {
                        // FromEncodedData(Stream) reads the stream into native memory,
                        // so each file can be closed straight away
                        using var imageStream = i.LocalFile!.Info.OpenRead();
                        return SKImage.FromEncodedData(imageStream);
                    })
                    .ToImmutableArray();

                try
                {
                    using var grid = ImageProcessor.CreateImageGrid(loadedImages);
                    using var gridData = grid.Encode();

                    // Save to disk, named after the last image of the batch
                    var lastName = outputImages[^1].LocalFile!.Info.Name;
                    var gridPath = gridOutputDir.JoinFile($"grid-{lastName}");

                    // Create() truncates any existing file; OpenWrite() would leave stale
                    // trailing bytes. Close the stream before the gallery can load the file.
                    await using (var fileStream = gridPath.Info.Create())
                    {
                        await fileStream.WriteAsync(gridData.ToArray(), cancellationToken);
                    }

                    // Insert to start of images
                    ImageGalleryCardViewModel.ImageSources.Add(new ImageSource(gridPath));
                }
                finally
                {
                    foreach (var loadedImage in loadedImages)
                    {
                        loadedImage?.Dispose();
                    }
                }
            }

            // Add rest of images
            ImageGalleryCardViewModel.ImageSources.AddRange(outputImages);
        }
        finally
        {
            // Detach event handlers first, so events raised from here on no longer
            // update the state being reset below
            client.PreviewImageReceived -= OnPreviewImageReceived;
            if (promptTask is not null)
            {
                promptTask.ProgressUpdate -= OnProgressUpdateReceived;
                promptTask.Dispose();
            }

            OutputProgress.Value = 0;
            OutputProgress.Text = "";

            // Clear the preview before disposing the bitmap it referenced
            var lastPreview = ImageGalleryCardViewModel.PreviewImage;
            ImageGalleryCardViewModel.PreviewImage = null;
            ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
            lastPreview?.Dispose();
        }

        // Sends an interrupt request. The request is given at most 5 seconds, and the
        // timeout source is released once the request completes.
        async Task InterruptWithTimeoutAsync()
        {
            using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
            await client.InterruptPromptAsync(timeoutCts.Token);
        }
    }

    /// <summary>
    /// Command entry point for image generation. A user cancellation is logged rather than
    /// surfaced as an error.
    /// </summary>
    [RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
    private async Task GenerateImage(CancellationToken cancellationToken = default)
    {
        try
        {
            await GenerateImageImpl(cancellationToken);
        }
        catch (OperationCanceledException e)
        {
            Logger.Debug($"[Image Generation Canceled] {e.Message}");
        }
    }
}