using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Avalonia.Media.Imaging;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using SkiaSharp;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
// CS0657 is suppressed because fields below use `[property: ...]` attribute targets, presumably
// applied to properties generated by CommunityToolkit.Mvvm source generators — confirm.
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
/// <summary>
/// View model for text-to-image inference: holds the prompt, gallery and settings-card view models,
/// turns their values into a prompt graph, and queues it on the inference client.
/// </summary>
/// <remarks>
/// Declared <c>partial</c> so source generators (e.g. <c>[RelayCommand]</c>) can add members.
/// </remarks>
public partial class InferenceTextToImageViewModel : LoadableViewModelBase
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
// Used to show user-facing notifications (e.g. when no client is connected)
private readonly INotificationService notificationService;
// Resolves sub view models (cards) by type
private readonly ServiceManager<ViewModelBase> vmFactory;
/// <summary>Connection manager for the inference backend; provides the client and its connection state.</summary>
public IInferenceClientManager ClientManager { get; }
/// <summary>Gallery of generated images, plus the live preview overlay.</summary>
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
/// <summary>Positive and negative prompt text.</summary>
public PromptCardViewModel PromptCardViewModel { get; }
/// <summary>Stack of settings cards (model, sampler, hires fix, seed, batch size).</summary>
public StackCardViewModel StackCardViewModel { get; }
/// <summary>Progress of the current generation, updated from client progress events.</summary>
public ProgressViewModel OutputProgress { get; } = new();
// `property:` target puts JsonIgnore on the generated property rather than on this backing field
[property: JsonIgnore]
private string? outputImageSource;
/// <summary>
/// Creates the view model, resolves its sub view models and fills
/// <see cref="StackCardViewModel"/> with the settings cards.
/// </summary>
/// <param name="notificationService">Used to show user notifications.</param>
/// <param name="inferenceClientManager">Exposed as <see cref="ClientManager"/>.</param>
/// <param name="vmFactory">
/// Resolves each card view model. The <c>Get&lt;T&gt;(callback)</c> overload presumably runs the
/// callback on the newly created instance to configure it — confirm.
/// </param>
public InferenceTextToImageViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
ServiceManager<ViewModelBase> vmFactory
this.notificationService = notificationService;
this.vmFactory = vmFactory;
ClientManager = inferenceClientManager;
// Get sub view models from service manager
// Seed card is resolved up front as a local, presumably so it can be placed under "Seed" below — confirm
var seedCard = vmFactory.Get<SeedCardViewModel>();
ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
PromptCardViewModel = vmFactory.Get<PromptCardViewModel>();
StackCardViewModel = vmFactory.Get<StackCardViewModel>();
// Settings cards, in the order listed
StackCardViewModel.AddCards(new LoadableViewModelBase[]
// Model Card
// Sampler
// Hires Fix
// Expander titled "Hires Fix" holding its own nested cards
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
stackExpander.Title = "Hires Fix";
stackExpander.AddCards(new LoadableViewModelBase[]
// Hires Fix Upscaler
// Hires Fix Sampler
// Second sampler card: dimensions, CFG scale and sampler selection disabled; denoise strength enabled.
// NOTE(review): GetCurrentPrompt calls StackCardViewModel.GetCard<SamplerCardViewModel>(); confirm it
// resolves the top-level sampler and not this nested one.
vmFactory.Get<SamplerCardViewModel>(samplerCard =>
samplerCard.IsDimensionsEnabled = false;
samplerCard.IsCfgScaleEnabled = false;
samplerCard.IsSamplerSelectionEnabled = false;
samplerCard.IsDenoiseStrengthEnabled = true;
// Seed
// Batch Size
/// <summary>
/// Builds the prompt graph for one text-to-image pass from the current card values.
/// </summary>
/// <remarks>
/// Node ids: "4" checkpoint loader, "6"/"7" positive/negative text encoders, "5" empty latent,
/// "3" KSampler, "8" VAE decode, "9" SaveImage. Inputs that reference another node are
/// two-element arrays of (source node id, output slot). The graph contains no upscale or
/// hires-fix nodes and always samples with a denoise of 1.
/// </remarks>
/// <returns>A new map from node id to node, holding the card values read at call time.</returns>
private Dictionary<string, ComfyNode> GetCurrentPrompt()
{
    var sampler = StackCardViewModel.GetCard<SamplerCardViewModel>();
    var batchCard = StackCardViewModel.GetCard<BatchSizeCardViewModel>();
    var modelCard = StackCardViewModel.GetCard<ModelCardViewModel>();
    var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();

    // Reference to output `slot` of node `nodeId`
    static object[] Link(string nodeId, int slot) => new object[] { nodeId, slot };

    var kSampler = new ComfyNode
    {
        ClassType = "KSampler",
        Inputs = new Dictionary<string, object?>
        {
            ["cfg"] = sampler.CfgScale,
            ["denoise"] = 1,
            ["latent_image"] = Link("5", 0),
            ["model"] = Link("4", 0),
            ["negative"] = Link("7", 0),
            ["positive"] = Link("6", 0),
            ["sampler_name"] = sampler.SelectedSampler?.Name,
            ["scheduler"] = "normal",
            ["seed"] = seedCard.Seed,
            ["steps"] = sampler.Steps
        }
    };

    // Slot 0 = model, 1 = clip, 2 = vae, as consumed by the other nodes
    var checkpointLoader = new ComfyNode
    {
        ClassType = "CheckpointLoaderSimple",
        Inputs = new Dictionary<string, object?> { ["ckpt_name"] = modelCard.SelectedModelName }
    };

    var emptyLatent = new ComfyNode
    {
        ClassType = "EmptyLatentImage",
        Inputs = new Dictionary<string, object?>
        {
            ["batch_size"] = batchCard.BatchSize,
            ["height"] = sampler.Height,
            ["width"] = sampler.Width,
        }
    };

    var positiveEncode = new ComfyNode
    {
        ClassType = "CLIPTextEncode",
        Inputs = new Dictionary<string, object?>
        {
            ["clip"] = Link("4", 1),
            ["text"] = PromptCardViewModel.PromptDocument.Text,
        }
    };

    var negativeEncode = new ComfyNode
    {
        ClassType = "CLIPTextEncode",
        Inputs = new Dictionary<string, object?>
        {
            ["clip"] = Link("4", 1),
            ["text"] = PromptCardViewModel.NegativePromptDocument.Text,
        }
    };

    var vaeDecode = new ComfyNode
    {
        ClassType = "VAEDecode",
        Inputs = new Dictionary<string, object?>
        {
            ["samples"] = Link("3", 0),
            ["vae"] = Link("4", 2)
        }
    };

    var saveImage = new ComfyNode
    {
        ClassType = "SaveImage",
        Inputs = new Dictionary<string, object?>
        {
            ["filename_prefix"] = "SM-Inference",
            ["images"] = Link("8", 0)
        }
    };

    return new Dictionary<string, ComfyNode>
    {
        ["3"] = kSampler,
        ["4"] = checkpointLoader,
        ["5"] = emptyLatent,
        ["6"] = positiveEncode,
        ["7"] = negativeEncode,
        ["8"] = vaeDecode,
        ["9"] = saveImage
    };
}
/// <summary>
/// Client progress event handler: copies the reported value and maximum into
/// <see cref="OutputProgress"/> and switches it to determinate mode.
/// </summary>
private void OnProgressUpdateReceived(object? sender, ComfyWebSocketProgressData args)
{
    var progress = OutputProgress;
    progress.Value = args.Value;
    progress.Maximum = args.Max;
    progress.IsIndeterminate = false;
}
/// <summary>
/// Client preview event handler: decodes the received image bytes into a bitmap, shows it as
/// the gallery preview, and enables the preview overlay.
/// </summary>
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
    var gallery = ImageGalleryCardViewModel;

    // Stream lives until the end of the method
    using var imageStream = new MemoryStream(args.ImageBytes);

    // NOTE(review): the previously assigned PreviewImage is replaced without being disposed here —
    // confirm whether it should be.
    gallery.PreviewImage = new Bitmap(imageStream);
    gallery.IsPreviewOverlayEnabled = true;
}
/// <summary>
/// Builds the current prompt graph, queues it on the connected inference client and waits for it
/// to finish, then collects the "SaveImage" (node "9") outputs. When there are several outputs, a
/// grid image is composed, saved next to them, and added to the gallery.
/// </summary>
/// <remarks>
/// While the prompt runs, client progress and preview events are forwarded to
/// <see cref="OutputProgress"/> and the gallery preview; the handlers are detached at the end.
/// </remarks>
/// <param name="cancellationToken">
/// Passed to queueing, waiting on the prompt, and writing the grid file.
/// </param>
private async Task GenerateImageImpl(CancellationToken cancellationToken = default)
if (!ClientManager.IsConnected)
notificationService.Show("Client not connected", "Please connect first");
// If enabled, randomize the seed
var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();
if (seedCard.IsRandomizeEnabled)
var client = ClientManager.Client;
// Snapshot card values into the node graph before queueing
var nodes = GetCurrentPrompt();
// Connect progress handler
// NOTE(review): handlers are subscribed on every run; confirm the detach below runs on every exit
// path (e.g. in a finally), otherwise handlers accumulate on the shared client.
client.ProgressUpdateReceived += OnProgressUpdateReceived;
client.PreviewImageReceived += OnPreviewImageReceived;
var (response, promptTask) = await client.QueuePromptAsync(nodes, cancellationToken);
// Wait for prompt to finish
// NOTE(review): cancelling only stops waiting locally; nothing here interrupts the queued prompt.
await promptTask.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {response.PromptId} finished");
// Get output images
var outputs = await client.GetImagesForExecutedPromptAsync(
// Only get the SaveImage images from node 9
// NOTE(review): if `outputs` is a Dictionary, a missing "9" key throws instead of yielding null.
var images = outputs["9"];
if (images is null) return;
List<ImageSource> outputImages;
// Use local file path if available, otherwise use remote URL
if (client.OutputImagesDir is { } outputPath)
outputImages = images
.Select(i => new ImageSource(i.ToFilePath(outputPath)))
outputImages = images
.Select(i => new ImageSource(i.ToUri(client.BaseAddress)))
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
var loadedImages = outputImages.Select(i =>
var grid = ImageProcessor.CreateImageGrid(loadedImages);
// Save to disk
// NOTE(review): in the remote-URL branch above, OutputImagesDir is presumably null and LocalFile
// unset, so `lastName` becomes null ("grid-") and the `!` below would throw — confirm.
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = client.OutputImagesDir!.JoinFile($"grid-{lastName}");
// NOTE(review): if Info is a FileInfo, OpenWrite does not truncate an existing file, so leftover
// bytes from a longer previous file would remain. The `await using` declaration also keeps the
// stream open until the end of its enclosing scope, which may be after the gallery item is added.
await using var fileStream = gridPath.Info.OpenWrite();
// NOTE(review): if `grid` is an SKImage, neither it nor the SKData from Encode() is disposed here.
await fileStream.WriteAsync(grid.Encode().ToArray(), cancellationToken);
// Insert to start of gallery
// NOTE(review): `Add` appends to the end, contradicting the comment above; use Insert(0, ...) if
// the start was intended.
ImageGalleryCardViewModel.ImageSources.Add(new ImageSource(gridPath));
// var bitmaps = (await outputImages.SelectAsync(async i => await i.GetBitmapAsync())).ToImmutableArray();
// Insert rest of images
// Disconnect progress handler
// Reset progress and preview overlay, then detach the client event handlers
OutputProgress.Value = 0;
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
client.ProgressUpdateReceived -= OnProgressUpdateReceived;
client.PreviewImageReceived -= OnPreviewImageReceived;
/// <summary>
/// Command entry point for image generation. It delegates to <c>GenerateImageImpl</c>.
/// </summary>
/// <remarks>
/// <c>IncludeCancelCommand = true</c> asks the generator for a companion cancel command.
/// An <see cref="OperationCanceledException"/> is logged at Debug level and not rethrown.
/// </remarks>
/// <param name="cancellationToken">Forwarded to the generation; cancels it when signalled.</param>
[RelayCommand(IncludeCancelCommand = true)]
private async Task GenerateImage(CancellationToken cancellationToken = default)
{
    try
    {
        await GenerateImageImpl(cancellationToken);
    }
    catch (OperationCanceledException canceled)
    {
        Logger.Debug($"[Image Generation Canceled] {canceled.Message}");
    }
}