Multi-Platform Package Manager for Stable Diffusion
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

314 lines
10 KiB

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Avalonia.Media.Imaging;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView))]
public partial class InferenceTextToImageViewModel : ViewModelBase, ILoadableState<InferenceTextToImageModel>
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService;
private readonly ServiceManager<ViewModelBase> vmFactory;
public IInferenceClientManager ClientManager { get; }
// These are set in OnLoaded due to needing the vmFactory
[NotNull] public SeedCardViewModel? SeedCardViewModel { get; private set; }
[NotNull] public SamplerCardViewModel? SamplerCardViewModel { get; private set; }
[NotNull] public SamplerCardViewModel? HiresFixSamplerCardViewModel { get; private set; }
[NotNull] public ImageGalleryCardViewModel? ImageGalleryCardViewModel { get; private set; }
public InferenceViewModel? Parent { get; set; }
public TextDocument PromptDocument { get; } = new();
public TextDocument NegativePromptDocument { get; } = new();
[ObservableProperty]
private string? selectedModelName;
[ObservableProperty]
private int batchSize = 1;
public ProgressViewModel OutputProgress { get; } = new();
[ObservableProperty]
private string? outputImageSource;
public InferenceTextToImageViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
ServiceManager<ViewModelBase> vmFactory)
{
this.notificationService = notificationService;
this.vmFactory = vmFactory;
// ReSharper disable twice NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract
SeedCardViewModel ??= vmFactory.Get<SeedCardViewModel>();
// OnPropertyChanged(nameof(SeedCardViewModel));
SamplerCardViewModel ??= vmFactory.Get<SamplerCardViewModel>();
// OnPropertyChanged(nameof(SamplerCardViewModel));
HiresFixSamplerCardViewModel ??= vmFactory.Get<SamplerCardViewModel>(vm =>
{
vm.IsScaleSizeMode = true;
vm.IsCfgScaleEnabled = false;
vm.IsSamplerSelectionEnabled = false;
vm.IsDenoiseStrengthEnabled = true;
});
// OnPropertyChanged(nameof(HiresFixSamplerCardViewModel));
// ImageGalleryCardViewModel ??= vmFactory.Get<ImageGalleryCardViewModel>();
ImageGalleryCardViewModel = new ImageGalleryCardViewModel();
// OnPropertyChanged(nameof(ImageGalleryCardViewModel));
SeedCardViewModel.GenerateNewSeed();
ClientManager = inferenceClientManager;
}
private Dictionary<string, ComfyNode> GetCurrentPrompt()
{
var prompt = new Dictionary<string, ComfyNode>
{
["3"] = new()
{
ClassType = "KSampler",
Inputs = new Dictionary<string, object?>
{
["cfg"] = SamplerCardViewModel.CfgScale,
["denoise"] = 1,
["latent_image"] = new object[]
{
"5",
0
},
["model"] = new object[]
{
"4",
0
},
["negative"] = new object[]
{
"7",
0
},
["positive"] = new object[]
{
"6",
0
},
["sampler_name"] = SamplerCardViewModel.SelectedSampler,
["scheduler"] = "normal",
["seed"] = SeedCardViewModel.Seed,
["steps"] = SamplerCardViewModel.Steps
}
},
["4"] = new()
{
ClassType = "CheckpointLoaderSimple",
Inputs = new Dictionary<string, object?>
{
["ckpt_name"] = SelectedModelName
}
},
["5"] = new()
{
ClassType = "EmptyLatentImage",
Inputs = new Dictionary<string, object?>
{
["batch_size"] = BatchSize,
["height"] = SamplerCardViewModel.Height,
["width"] = SamplerCardViewModel.Width,
}
},
["6"] = new()
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?>
{
["clip"] = new object[]
{
"4",
1
},
["text"] = PromptDocument.Text,
}
},
["7"] = new()
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?>
{
["clip"] = new object[]
{
"4",
1
},
["text"] = NegativePromptDocument.Text,
}
},
["8"] = new()
{
ClassType = "VAEDecode",
Inputs = new Dictionary<string, object?>
{
["samples"] = new object[]
{
"3",
0
},
["vae"] = new object[]
{
"4",
2
}
}
},
["9"] = new()
{
ClassType = "SaveImage",
Inputs = new Dictionary<string, object?>
{
["filename_prefix"] = "SM-Inference",
["images"] = new object[]
{
"8",
0
}
}
}
};
return prompt;
}
private void OnProgressUpdateReceived(object? sender, ComfyWebSocketProgressData args)
{
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Max;
OutputProgress.IsIndeterminate = false;
}
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
// Decode to bitmap
using var stream = new MemoryStream(args.ImageBytes);
var bitmap = new Bitmap(stream);
ImageGalleryCardViewModel.PreviewImage = bitmap;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = true;
}
private async Task GenerateImageImpl(CancellationToken cancellationToken = default)
{
if (!ClientManager.IsConnected)
{
notificationService.Show("Client not connected", "Please connect first");
return;
}
// If enabled, randomize the seed
if (SeedCardViewModel.IsRandomizeEnabled)
{
SeedCardViewModel.GenerateNewSeed();
}
var client = ClientManager.Client;
var nodes = GetCurrentPrompt();
// Connect progress handler
OutputProgress.IsIndeterminate = true;
client.ProgressUpdateReceived += OnProgressUpdateReceived;
client.PreviewImageReceived += OnPreviewImageReceived;
try
{
var (response, promptTask) = await client.QueuePromptAsync(nodes, cancellationToken);
Logger.Info(response);
// Wait for prompt to finish
await promptTask.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {response.PromptId} finished");
// Get output images
var outputs = await client
.GetImagesForExecutedPromptAsync(response.PromptId, cancellationToken);
// Only get the SaveImage from node 9
var images = outputs["9"];
if (images is null) return;
ImageGalleryCardViewModel.ImageSources.Clear();
ImageGalleryCardViewModel.ImageSources.AddRange(
images.Select(i => i.ToUri(client.BaseAddress).ToString()));
}
finally
{
// Disconnect progress handler
OutputProgress.Value = 0;
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
client.ProgressUpdateReceived -= OnProgressUpdateReceived;
client.PreviewImageReceived -= OnPreviewImageReceived;
}
}
[RelayCommand(IncludeCancelCommand = true)]
private async Task GenerateImage(CancellationToken cancellationToken = default)
{
try
{
await GenerateImageImpl(cancellationToken);
}
catch (OperationCanceledException e)
{
Logger.Debug($"[Image Generation Canceled] {e.Message}");
}
}
/// <inheritdoc />
public void LoadState(InferenceTextToImageModel state)
{
PromptDocument.Text = state.Prompt;
NegativePromptDocument.Text = state.NegativePrompt;
SelectedModelName = state.SelectedModelName;
if (state.SeedCardState != null)
{
SeedCardViewModel.LoadState(state.SeedCardState);
}
if (state.SamplerCardState != null)
{
SamplerCardViewModel.LoadState(state.SamplerCardState);
}
}
/// <inheritdoc />
public InferenceTextToImageModel SaveState()
{
return new InferenceTextToImageModel
{
Prompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text,
SelectedModelName = SelectedModelName,
SeedCardState = SeedCardViewModel.SaveState(),
SamplerCardState = SamplerCardViewModel.SaveState()
};
}
}