using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Avalonia.Media.Imaging;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// View model behind <c>InferenceTextToImageView</c>. It builds a Comfy node graph from the
/// current prompt, model, seed and sampler settings, queues it on the connected client, and
/// publishes the resulting image URIs to the image gallery card.
/// </summary>
[View(typeof(InferenceTextToImageView))]
public partial class InferenceTextToImageViewModel : ViewModelBase, ILoadableState
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Id of the SaveImage node in the graph built by <see cref="GetCurrentPrompt"/>.
    /// The same id is used to look up the generated images once the prompt has executed,
    /// so both places must agree.
    /// </summary>
    private const string SaveImageNodeId = "9";

    private readonly INotificationService notificationService;
    private readonly ServiceManager vmFactory;

    /// <summary>Provides the inference client and its connection state.</summary>
    public IInferenceClientManager ClientManager { get; }

    // The card view models below are assigned exactly once, in the constructor.
    // They are declared nullable with [NotNull] so that reads are treated as non-null.

    /// <summary>Seed settings; supplies the KSampler <c>seed</c> input.</summary>
    [NotNull]
    public SeedCardViewModel? SeedCardViewModel { get; private set; }

    /// <summary>Primary sampler settings (cfg scale, sampler, steps, width, height).</summary>
    [NotNull]
    public SamplerCardViewModel? SamplerCardViewModel { get; private set; }

    /// <summary>
    /// Sampler card configured for hires-fix in the constructor.
    /// It is not referenced by the node graph built in <see cref="GetCurrentPrompt"/>.
    /// </summary>
    [NotNull]
    public SamplerCardViewModel? HiresFixSamplerCardViewModel { get; private set; }

    /// <summary>Gallery that receives preview bitmaps and final image URIs.</summary>
    [NotNull]
    public ImageGalleryCardViewModel? ImageGalleryCardViewModel { get; private set; }

    /// <summary>Settable from outside; not read anywhere in this class.</summary>
    public InferenceViewModel? Parent { get; set; }

    /// <summary>Document holding the positive prompt text.</summary>
    public TextDocument PromptDocument { get; } = new();

    /// <summary>Document holding the negative prompt text.</summary>
    public TextDocument NegativePromptDocument { get; } = new();

    // Passed as "ckpt_name" to the CheckpointLoaderSimple node.
    [ObservableProperty]
    private string? selectedModelName;

    // Passed as "batch_size" to the EmptyLatentImage node.
    [ObservableProperty]
    private int batchSize = 1;

    /// <summary>Progress state updated from the client's progress events while generating.</summary>
    public ProgressViewModel OutputProgress { get; } = new();

    // Not read or written by any code in this class.
    [ObservableProperty]
    private string? outputImageSource;

    /// <summary>
    /// Creates the view model, resolves the card view models and generates an initial seed.
    /// </summary>
    /// <param name="notificationService">Used to tell the user when no client is connected.</param>
    /// <param name="inferenceClientManager">Exposed as <see cref="ClientManager"/>.</param>
    /// <param name="vmFactory">Factory used to resolve the seed and sampler card view models.</param>
    public InferenceTextToImageViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ServiceManager vmFactory)
    {
        this.notificationService = notificationService;
        this.vmFactory = vmFactory;
        ClientManager = inferenceClientManager;

        SeedCardViewModel = vmFactory.Get();
        SamplerCardViewModel = vmFactory.Get();

        // The hires-fix card only exposes scale-size and denoise strength.
        HiresFixSamplerCardViewModel = vmFactory.Get(vm =>
        {
            vm.IsScaleSizeMode = true;
            vm.IsCfgScaleEnabled = false;
            vm.IsSamplerSelectionEnabled = false;
            vm.IsDenoiseStrengthEnabled = true;
        });

        // Constructed directly rather than resolved through the factory.
        ImageGalleryCardViewModel = new ImageGalleryCardViewModel();

        SeedCardViewModel.GenerateNewSeed();
    }

    /// <summary>
    /// Builds the node graph for the current settings, keyed by node id.
    /// </summary>
    /// <remarks>
    /// A link between nodes is written as a two element array: the id of the source node
    /// followed by the index of the output being consumed. For example the checkpoint
    /// loader ("4") feeds index 0 into "model", index 1 into "clip" and index 2 into "vae".
    /// </remarks>
    /// <returns>The graph: checkpoint loader, empty latent, two text encoders, sampler, VAE decode and image save.</returns>
    private Dictionary GetCurrentPrompt()
    {
        var prompt = new Dictionary
        {
            ["3"] = new()
            {
                ClassType = "KSampler",
                Inputs = new Dictionary
                {
                    ["cfg"] = SamplerCardViewModel.CfgScale,
                    ["denoise"] = 1,
                    ["latent_image"] = new object[] { "5", 0 },
                    ["model"] = new object[] { "4", 0 },
                    ["negative"] = new object[] { "7", 0 },
                    ["positive"] = new object[] { "6", 0 },
                    ["sampler_name"] = SamplerCardViewModel.SelectedSampler,
                    ["scheduler"] = "normal",
                    ["seed"] = SeedCardViewModel.Seed,
                    ["steps"] = SamplerCardViewModel.Steps
                }
            },
            ["4"] = new()
            {
                ClassType = "CheckpointLoaderSimple",
                Inputs = new Dictionary
                {
                    ["ckpt_name"] = SelectedModelName
                }
            },
            ["5"] = new()
            {
                ClassType = "EmptyLatentImage",
                Inputs = new Dictionary
                {
                    ["batch_size"] = BatchSize,
                    ["height"] = SamplerCardViewModel.Height,
                    ["width"] = SamplerCardViewModel.Width,
                }
            },
            // Positive prompt encoder.
            ["6"] = new()
            {
                ClassType = "CLIPTextEncode",
                Inputs = new Dictionary
                {
                    ["clip"] = new object[] { "4", 1 },
                    ["text"] = PromptDocument.Text,
                }
            },
            // Negative prompt encoder.
            ["7"] = new()
            {
                ClassType = "CLIPTextEncode",
                Inputs = new Dictionary
                {
                    ["clip"] = new object[] { "4", 1 },
                    ["text"] = NegativePromptDocument.Text,
                }
            },
            ["8"] = new()
            {
                ClassType = "VAEDecode",
                Inputs = new Dictionary
                {
                    ["samples"] = new object[] { "3", 0 },
                    ["vae"] = new object[] { "4", 2 }
                }
            },
            [SaveImageNodeId] = new()
            {
                ClassType = "SaveImage",
                Inputs = new Dictionary
                {
                    ["filename_prefix"] = "SM-Inference",
                    ["images"] = new object[] { "8", 0 }
                }
            }
        };

        return prompt;
    }

    /// <summary>
    /// Copies a progress event from the client into <see cref="OutputProgress"/> and
    /// switches the progress display out of indeterminate mode.
    /// </summary>
    private void OnProgressUpdateReceived(object? sender, ComfyWebSocketProgressData args)
    {
        OutputProgress.Value = args.Value;
        OutputProgress.Maximum = args.Max;
        OutputProgress.IsIndeterminate = false;
    }

    /// <summary>
    /// Decodes the preview image bytes sent by the client and shows them as the gallery's
    /// preview overlay.
    /// </summary>
    private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
    {
        // The stream is only needed while the bitmap is constructed; it is disposed on return.
        using var stream = new MemoryStream(args.ImageBytes);
        var bitmap = new Bitmap(stream);

        // NOTE(review): the bitmap previously assigned to PreviewImage is not disposed
        // here - confirm whether the gallery view model owns and releases it.
        ImageGalleryCardViewModel.PreviewImage = bitmap;
        ImageGalleryCardViewModel.IsPreviewOverlayEnabled = true;
    }

    /// <summary>
    /// Queues the current prompt on the connected client, waits for it to finish and
    /// replaces the gallery's image sources with the images saved by the SaveImage node.
    /// </summary>
    /// <remarks>
    /// Shows a notification and returns without doing anything when the client is not
    /// connected. Progress and preview handlers are attached for the duration of the run
    /// and always detached afterwards, together with a reset of the progress and preview state.
    /// </remarks>
    /// <param name="cancellationToken">Cancels queueing, waiting and fetching of the results.</param>
    private async Task GenerateImageImpl(CancellationToken cancellationToken = default)
    {
        if (!ClientManager.IsConnected)
        {
            notificationService.Show("Client not connected", "Please connect first");
            return;
        }

        // If enabled, randomize the seed
        if (SeedCardViewModel.IsRandomizeEnabled)
        {
            SeedCardViewModel.GenerateNewSeed();
        }

        var client = ClientManager.Client;
        var nodes = GetCurrentPrompt();

        // Connect progress handlers; indeterminate until the first progress event arrives.
        OutputProgress.IsIndeterminate = true;
        client.ProgressUpdateReceived += OnProgressUpdateReceived;
        client.PreviewImageReceived += OnPreviewImageReceived;

        try
        {
            var (response, promptTask) = await client.QueuePromptAsync(nodes, cancellationToken);
            Logger.Info(response);

            // Wait for prompt to finish
            await promptTask.WaitAsync(cancellationToken);
            Logger.Trace($"Prompt task {response.PromptId} finished");

            // Get output images
            var outputs = await client
                .GetImagesForExecutedPromptAsync(response.PromptId, cancellationToken);

            // Only the images written by the SaveImage node are of interest.
            // NOTE(review): if outputs is a plain dictionary this indexer throws when the
            // key is absent, so the null check below only covers a present-but-null entry.
            var images = outputs[SaveImageNodeId];
            if (images is null)
            {
                return;
            }

            ImageGalleryCardViewModel.ImageSources.Clear();
            ImageGalleryCardViewModel.ImageSources.AddRange(
                images.Select(i => i.ToUri(client.BaseAddress).ToString()));
        }
        finally
        {
            // Reset progress. IsIndeterminate must be cleared here as well: when the run is
            // cancelled, fails, or completes without a single progress event, nothing else
            // would turn it off again once the handlers are detached.
            OutputProgress.Value = 0;
            OutputProgress.IsIndeterminate = false;

            ImageGalleryCardViewModel.PreviewImage = null;
            ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;

            // Disconnect progress handlers
            client.ProgressUpdateReceived -= OnProgressUpdateReceived;
            client.PreviewImageReceived -= OnPreviewImageReceived;
        }
    }

    /// <summary>
    /// Command entry point for image generation. Cancellation is treated as a normal
    /// outcome and only logged; any other exception propagates to the caller.
    /// </summary>
    /// <param name="cancellationToken">Token supplied by the generated cancel command.</param>
    [RelayCommand(IncludeCancelCommand = true)]
    private async Task GenerateImage(CancellationToken cancellationToken = default)
    {
        try
        {
            await GenerateImageImpl(cancellationToken);
        }
        catch (OperationCanceledException e)
        {
            Logger.Debug($"[Image Generation Canceled] {e.Message}");
        }
    }

    /// <summary>
    /// Restores prompt texts, the selected model name and, when present, the seed and
    /// sampler card states from <paramref name="state"/>.
    /// </summary>
    /// <param name="state">Previously saved state; null card states are skipped.</param>
    public void LoadState(InferenceTextToImageModel state)
    {
        PromptDocument.Text = state.Prompt;
        NegativePromptDocument.Text = state.NegativePrompt;
        SelectedModelName = state.SelectedModelName;

        if (state.SeedCardState != null)
        {
            SeedCardViewModel.LoadState(state.SeedCardState);
        }

        if (state.SamplerCardState != null)
        {
            SamplerCardViewModel.LoadState(state.SamplerCardState);
        }
    }

    /// <summary>
    /// Captures prompt texts, the selected model name and the seed and sampler card states.
    /// The batch size and the hires-fix sampler card are not included.
    /// </summary>
    /// <returns>A new state object reflecting the current values.</returns>
    public InferenceTextToImageModel SaveState()
    {
        return new InferenceTextToImageModel
        {
            Prompt = PromptDocument.Text,
            NegativePrompt = NegativePromptDocument.Text,
            SelectedModelName = SelectedModelName,
            SeedCardState = SeedCardViewModel.SaveState(),
            SamplerCardState = SamplerCardViewModel.SaveState()
        };
    }
}