From 840f664c34fc64c615fb18f9e95ab0fcc850f8c7 Mon Sep 17 00:00:00 2001 From: Ionite Date: Sat, 14 Oct 2023 17:47:15 -0400 Subject: [PATCH] Refactor for union latent/image node building --- .../Extensions/ComfyNodeBuilderExtensions.cs | 43 ++-- .../InferenceImageUpscaleViewModel.cs | 49 ++-- .../InferenceTextToImageViewModel.cs | 59 +++-- .../Extensions/SizeExtensions.cs | 11 + .../Comfy/NodeTypes/PrimaryNodeConnection.cs | 10 + .../Api/Comfy/Nodes/ComfyNodeBuilder.cs | 212 +++++++++++++----- 6 files changed, 248 insertions(+), 136 deletions(-) create mode 100644 StabilityMatrix.Core/Extensions/SizeExtensions.cs create mode 100644 StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs index e04f4f46..322b45c1 100644 --- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs +++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs @@ -26,8 +26,8 @@ public static class ComfyNodeBuilderExtensions ) ); - builder.Connections.Latent = emptyLatent.Output; - builder.Connections.LatentSize = new Size( + builder.Connections.Primary = emptyLatent.Output; + builder.Connections.PrimarySize = new Size( samplerCardViewModel.Width, samplerCardViewModel.Height ); @@ -35,11 +35,11 @@ public static class ComfyNodeBuilderExtensions // If batch index is selected, add a LatentFromBatch if (batchSizeCardViewModel.IsBatchIndexEnabled) { - builder.Connections.Latent = builder.Nodes + builder.Connections.Primary = builder.Nodes .AddNamedNode( ComfyNodeBuilder.LatentFromBatch( "LatentFromBatch", - builder.Connections.Latent, + builder.GetPrimaryAsLatent(), // remote expects a 0-based index, vm is 1-based batchSizeCardViewModel.BatchIndex - 1, 1 @@ -133,12 +133,12 @@ public static class ComfyNodeBuilderExtensions ?? throw new ValidationException("Sampler not selected"), positiveClip.Output, negativeClip.Output, - builder.Connections.Latent + builder.GetPrimaryAsLatent() ?? throw new ValidationException("Latent source not set"), samplerCardViewModel.DenoiseStrength ) ); - builder.Connections.Latent = sampler.Output; + builder.Connections.Primary = sampler.Output; } // Add base sampler (with refiner) else @@ -160,14 +160,13 @@ public static class ComfyNodeBuilderExtensions ?? throw new ValidationException("Sampler not selected"), positiveClip.Output, negativeClip.Output, - builder.Connections.Latent - ?? throw new ValidationException("Latent source not set"), + builder.GetPrimaryAsLatent(), 0, samplerCardViewModel.Steps, true ) ); - builder.Connections.Latent = sampler.Output; + builder.Connections.Primary = sampler.Output; } } @@ -255,38 +254,26 @@ public static class ComfyNodeBuilderExtensions ?? throw new ValidationException("Sampler not selected"), positiveClip.Output, negativeClip.Output, - builder.Connections.Latent - ?? throw new ValidationException("Latent source not set"), + builder.GetPrimaryAsLatent(), samplerCardViewModel.Steps, totalSteps, false ) ); - builder.Connections.Latent = sampler.Output; + + builder.Connections.Primary = sampler.Output; } public static string SetupOutputImage(this ComfyNodeBuilder builder) { - // Do VAE decoding if not done already - if (builder.Connections.Image is null) - { - var vaeDecoder = builder.Nodes.AddNamedNode( - ComfyNodeBuilder.VAEDecode( - "VAEDecode", - builder.Connections.Latent - ?? throw new InvalidOperationException("Latent source not set"), - builder.Connections.GetRefinerOrBaseVAE() - ) - ); - builder.Connections.Image = vaeDecoder.Output; - builder.Connections.ImageSize = builder.Connections.LatentSize; - } - var previewImage = builder.Nodes.AddNamedNode( new NamedComfyNode("SaveImage") { ClassType = "PreviewImage", - Inputs = new Dictionary { ["images"] = builder.Connections.Image } + Inputs = new Dictionary + { + ["images"] = builder.GetPrimaryAsImage().Data + } } ); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs index d97cf780..42ab1177 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs @@ -17,6 +17,7 @@ using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.Views.Inference; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Services; @@ -119,47 +120,47 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase ?? throw new InvalidOperationException("Source image size is null"); // Set source size - builder.Connections.ImageSize = sourceImageSize; + builder.Connections.PrimarySize = sourceImageSize; // Load source var loadImage = nodes.AddNamedNode( ComfyNodeBuilder.LoadImage("LoadImage", sourceImageRelativePath) ); - builder.Connections.Image = loadImage.Output1; + builder.Connections.Primary = loadImage.Output1; // If upscale is enabled, add another upscale group if (IsUpscaleEnabled) { - var upscaleSize = builder.Connections.GetScaledImageSize(UpscalerCardViewModel.Scale); - - // Build group - var upscaleGroup = builder.Group_UpscaleToImage( - "Upscale", - builder.Connections.Image!, - UpscalerCardViewModel.SelectedUpscaler!.Value, - upscaleSize.Width, - upscaleSize.Height + var upscaleSize = builder.Connections.PrimarySize.WithScale( + UpscalerCardViewModel.Scale ); - // Set as the image output - builder.Connections.Image = upscaleGroup.Output; + // Build group + builder.Connections.Primary = builder + .Group_UpscaleToImage( + "Upscale", + builder.GetPrimaryAsImage(), + UpscalerCardViewModel.SelectedUpscaler!.Value, + upscaleSize.Width, + upscaleSize.Height + ) + .Output; } // If sharpen is enabled, add another sharpen group if (IsSharpenEnabled) { - var sharpenGroup = nodes.AddNamedNode( - ComfyNodeBuilder.ImageSharpen( - "Sharpen", - builder.Connections.Image, - SharpenCardViewModel.SharpenRadius, - SharpenCardViewModel.Sigma, - SharpenCardViewModel.Alpha + builder.Connections.Primary = nodes + .AddNamedNode( + ComfyNodeBuilder.ImageSharpen( + "Sharpen", + builder.GetPrimaryAsImage(), + SharpenCardViewModel.SharpenRadius, + SharpenCardViewModel.Sigma, + SharpenCardViewModel.Alpha + ) ) - ); - - // Set as the image output - builder.Connections.Image = sharpenGroup.Output; + .Output; } builder.SetupOutputImage(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 929aa771..8f0e9389 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; +using System.Drawing; using System.Linq; using System.Text.Json.Serialization; using System.Threading; @@ -8,11 +9,13 @@ using System.Threading.Tasks; using DynamicData.Binding; using NLog; using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; @@ -259,34 +262,25 @@ public class InferenceTextToImageViewModel // If hi-res fix is enabled, add the LatentUpscale node and another KSampler node if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled) { - // Requested upscale to this size - var hiresSize = builder.Connections.GetScaledLatentSize( + // Get new latent size + var hiresSize = builder.Connections.PrimarySize.WithScale( HiresUpscalerCardViewModel.Scale ); - LatentNodeConnection hiresLatent; - // Select between latent upscale and normal upscale based on the upscale method var selectedUpscaler = HiresUpscalerCardViewModel.SelectedUpscaler!.Value; - if (selectedUpscaler.Type == ComfyUpscalerType.None) - { - // If no upscaler selected or none, just use the latent image - hiresLatent = builder.Connections.Latent!; - } - else + // If upscaler selected, upscale latent image first + if (selectedUpscaler.Type != ComfyUpscalerType.None) { - // Otherwise upscale the latent image - hiresLatent = builder - .Group_UpscaleToLatent( - "HiresFix", - builder.Connections.Latent!, - builder.Connections.GetRefinerOrBaseVAE(), - selectedUpscaler, - hiresSize.Width, - hiresSize.Height - ) - .Output; + builder.Connections.Primary = builder.Group_Upscale( + "HiresFix", + builder.Connections.Primary!, + builder.Connections.PrimaryVAE!, + selectedUpscaler, + hiresSize.Width, + hiresSize.Height + ); } // Use refiner model if set, or base if not @@ -306,33 +300,34 @@ public class InferenceTextToImageViewModel ?? throw new ValidationException("Scheduler not selected"), builder.Connections.GetRefinerOrBaseConditioning(), builder.Connections.GetRefinerOrBaseNegativeConditioning(), - hiresLatent, + builder.GetPrimaryAsLatent(), HiresSamplerCardViewModel.DenoiseStrength ) ); - // Set as latest latent - builder.Connections.Latent = hiresSampler.Output; - builder.Connections.LatentSize = hiresSize; + // Set as primary + builder.Connections.Primary = hiresSampler.Output; + builder.Connections.PrimarySize = hiresSize; } // If upscale is enabled, add another upscale group if (IsUpscaleEnabled) { - var upscaleSize = builder.Connections.GetScaledLatentSize(UpscalerCardViewModel.Scale); + var upscaleSize = builder.Connections.PrimarySize.WithScale( + UpscalerCardViewModel.Scale + ); - // Build group - var postUpscaleGroup = builder.Group_LatentUpscaleToImage( + var upscaleResult = builder.Group_Upscale( "PostUpscale", - builder.Connections.Latent!, - builder.Connections.GetRefinerOrBaseVAE(), + builder.Connections.Primary!, + builder.Connections.PrimaryVAE!, UpscalerCardViewModel.SelectedUpscaler!.Value, upscaleSize.Width, upscaleSize.Height ); - // Set as the image output - builder.Connections.Image = postUpscaleGroup.Output; + builder.Connections.Primary = upscaleResult; + builder.Connections.PrimarySize = upscaleSize; } builder.SetupOutputImage(); diff --git a/StabilityMatrix.Core/Extensions/SizeExtensions.cs b/StabilityMatrix.Core/Extensions/SizeExtensions.cs new file mode 100644 index 00000000..05f190e3 --- /dev/null +++ b/StabilityMatrix.Core/Extensions/SizeExtensions.cs @@ -0,0 +1,11 @@ +using System.Drawing; + +namespace StabilityMatrix.Core.Extensions; + +public static class SizeExtensions +{ + public static Size WithScale(this Size size, double scale) + { + return new Size((int)Math.Floor(size.Width * scale), (int)Math.Floor(size.Height * scale)); + } +} diff --git a/StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs b/StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs new file mode 100644 index 00000000..6c700c61 --- /dev/null +++ b/StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs @@ -0,0 +1,10 @@ +using OneOf; + +namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; + +/// +/// Union for the primary Image or Latent node connection +/// +[GenerateOneOf] +public partial class PrimaryNodeConnection + : OneOfBase { } diff --git a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs index 9e76fa26..0cbf37fe 100644 --- a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs +++ b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs @@ -3,7 +3,6 @@ using System.Drawing; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Database; -using StabilityMatrix.Core.Models.Tokens; namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes; @@ -15,10 +14,26 @@ public class ComfyNodeBuilder { public NodeDictionary Nodes { get; } = new(); - public Dictionary GlobalConnections { get; } = new(); - private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8]; + private string GetUniqueName(string nameBase) + { + var name = $"{nameBase}_1"; + for (var i = 0; Nodes.ContainsKey(name); i++) + { + if (i > 1_000_000) + { + throw new InvalidOperationException( + $"Could not find unique name for base {nameBase}" + ); + } + + name = $"{nameBase}_{i + 1}"; + } + + return name; + } + public static NamedComfyNode VAEEncode( string name, ImageNodeConnection pixels, @@ -338,7 +353,8 @@ public class ComfyNodeBuilder VAENodeConnection vae ) { - return Nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output; + var name = GetUniqueName("VAEDecode"); + return Nodes.AddNamedNode(VAEDecode(name, latent, vae)).Output; } public LatentNodeConnection Lambda_ImageToLatent( @@ -346,30 +362,8 @@ public class ComfyNodeBuilder VAENodeConnection vae ) { - return Nodes.AddNamedNode(VAEEncode($"{GetRandomPrefix()}_VAEEncode", pixels, vae)).Output; - } - - /// - /// Get a global connection for a given type - /// - public TConnection GetConnection() - where TConnection : NodeConnectionBase - { - if (GlobalConnections.TryGetValue(typeof(TConnection), out var connection)) - { - return (TConnection)connection; - } - - throw new InvalidOperationException($"No global connection of type {typeof(TConnection)}"); - } - - /// - /// Set a global connection for a given type - /// - public void SetConnection(TConnection connection) - where TConnection : NodeConnectionBase - { - GlobalConnections[typeof(TConnection)] = connection; + var name = GetUniqueName("VAEEncode"); + return Nodes.AddNamedNode(VAEEncode(name, pixels, vae)).Output; } /// @@ -392,6 +386,84 @@ public class ComfyNodeBuilder return upscaler; } + /// + /// Create a group node that scales a given image to image output + /// + public PrimaryNodeConnection Group_Upscale( + string name, + PrimaryNodeConnection primary, + VAENodeConnection vae, + ComfyUpscaler upscaleInfo, + int width, + int height + ) + { + if (upscaleInfo.Type == ComfyUpscalerType.Latent) + { + return primary.Match( + latent => + Nodes + .AddNamedNode( + new NamedComfyNode($"{name}_LatentUpscale") + { + ClassType = "LatentUpscale", + Inputs = new Dictionary + { + ["upscale_method"] = upscaleInfo.Name, + ["width"] = width, + ["height"] = height, + ["crop"] = "disabled", + ["samples"] = latent.Data, + } + } + ) + .Output, + image => + Nodes + .AddNamedNode( + ImageScale( + $"{name}_ImageUpscale", + image, + upscaleInfo.Name, + height, + width, + false + ) + ) + .Output + ); + } + + if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN) + { + // Convert to image space if needed + var samplerImage = GetPrimaryAsImage(primary, vae); + + // Do group upscale + var modelUpscaler = Group_UpscaleWithModel( + $"{name}_ModelUpscale", + upscaleInfo.Name, + samplerImage + ); + + // Since the model upscale is fixed to model (2x/4x), scale it again to the requested size + var resizedScaled = Nodes.AddNamedNode( + ImageScale( + $"{name}_ImageScale", + modelUpscaler.Output, + "bilinear", + height, + width, + false + ) + ); + + return resizedScaled.Output; + } + + throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}"); + } + /// /// Create a group node that scales a given image to a given size /// @@ -640,6 +712,60 @@ public class ComfyNodeBuilder return currentNode ?? throw new InvalidOperationException("No lora networks given"); } + /// + /// Get or convert latest primary connection to latent + /// + public LatentNodeConnection GetPrimaryAsLatent() + { + if (Connections.Primary?.IsT0 == true) + { + return Connections.Primary.AsT0; + } + + return GetPrimaryAsLatent( + Connections.Primary ?? throw new NullReferenceException("No primary connection"), + Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE") + ); + } + + /// + /// Get or convert latest primary connection to latent + /// + public LatentNodeConnection GetPrimaryAsLatent( + PrimaryNodeConnection primary, + VAENodeConnection vae + ) + { + return primary.Match(latent => latent, image => Lambda_ImageToLatent(image, vae)); + } + + /// + /// Get or convert latest primary connection to image + /// + public ImageNodeConnection GetPrimaryAsImage() + { + if (Connections.Primary?.IsT1 == true) + { + return Connections.Primary.AsT1; + } + + return GetPrimaryAsImage( + Connections.Primary ?? throw new NullReferenceException("No primary connection"), + Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE") + ); + } + + /// + /// Get or convert latest primary connection to image + /// + public ImageNodeConnection GetPrimaryAsImage( + PrimaryNodeConnection primary, + VAENodeConnection vae + ) + { + return primary.Match(latent => Lambda_LatentToImage(latent, vae), image => image); + } + /// /// Convert to a NodeDictionary /// @@ -666,38 +792,20 @@ public class ComfyNodeBuilder public ConditioningNodeConnection? RefinerConditioning { get; set; } public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; } - public LatentNodeConnection? Latent { get; set; } + public PrimaryNodeConnection? Primary { get; set; } + public VAENodeConnection? PrimaryVAE { get; set; } + public Size PrimarySize { get; set; } + + /*public LatentNodeConnection? Latent { get; set; } public Size LatentSize { get; set; } public ImageNodeConnection? Image { get; set; } - public Size ImageSize { get; set; } + public Size ImageSize { get; set; }*/ public List OutputNodes { get; } = new(); public IEnumerable OutputNodeNames => OutputNodes.Select(n => n.Name); - /// - /// Gets the latent size scaled by a given factor - /// - public Size GetScaledLatentSize(double scale) - { - return new Size( - (int)Math.Floor(LatentSize.Width * scale), - (int)Math.Floor(LatentSize.Height * scale) - ); - } - - /// - /// Gets the image size scaled by a given factor - /// - public Size GetScaledImageSize(double scale) - { - return new Size( - (int)Math.Floor(ImageSize.Width * scale), - (int)Math.Floor(ImageSize.Height * scale) - ); - } - public VAENodeConnection GetRefinerOrBaseVAE() { return RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE");