Browse Source

Refactor for union latent/image node building

pull/333/head
Ionite 1 year ago
parent
commit
840f664c34
No known key found for this signature in database
  1. 43
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 31
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
  3. 51
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  4. 11
      StabilityMatrix.Core/Extensions/SizeExtensions.cs
  5. 10
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs
  6. 212
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

43
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -26,8 +26,8 @@ public static class ComfyNodeBuilderExtensions
) )
); );
builder.Connections.Latent = emptyLatent.Output; builder.Connections.Primary = emptyLatent.Output;
builder.Connections.LatentSize = new Size( builder.Connections.PrimarySize = new Size(
samplerCardViewModel.Width, samplerCardViewModel.Width,
samplerCardViewModel.Height samplerCardViewModel.Height
); );
@ -35,11 +35,11 @@ public static class ComfyNodeBuilderExtensions
// If batch index is selected, add a LatentFromBatch // If batch index is selected, add a LatentFromBatch
if (batchSizeCardViewModel.IsBatchIndexEnabled) if (batchSizeCardViewModel.IsBatchIndexEnabled)
{ {
builder.Connections.Latent = builder.Nodes builder.Connections.Primary = builder.Nodes
.AddNamedNode( .AddNamedNode(
ComfyNodeBuilder.LatentFromBatch( ComfyNodeBuilder.LatentFromBatch(
"LatentFromBatch", "LatentFromBatch",
builder.Connections.Latent, builder.GetPrimaryAsLatent(),
// remote expects a 0-based index, vm is 1-based // remote expects a 0-based index, vm is 1-based
batchSizeCardViewModel.BatchIndex - 1, batchSizeCardViewModel.BatchIndex - 1,
1 1
@ -133,12 +133,12 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Sampler not selected"),
positiveClip.Output, positiveClip.Output,
negativeClip.Output, negativeClip.Output,
builder.Connections.Latent builder.GetPrimaryAsLatent()
?? throw new ValidationException("Latent source not set"), ?? throw new ValidationException("Latent source not set"),
samplerCardViewModel.DenoiseStrength samplerCardViewModel.DenoiseStrength
) )
); );
builder.Connections.Latent = sampler.Output; builder.Connections.Primary = sampler.Output;
} }
// Add base sampler (with refiner) // Add base sampler (with refiner)
else else
@ -160,14 +160,13 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Sampler not selected"),
positiveClip.Output, positiveClip.Output,
negativeClip.Output, negativeClip.Output,
builder.Connections.Latent builder.GetPrimaryAsLatent(),
?? throw new ValidationException("Latent source not set"),
0, 0,
samplerCardViewModel.Steps, samplerCardViewModel.Steps,
true true
) )
); );
builder.Connections.Latent = sampler.Output; builder.Connections.Primary = sampler.Output;
} }
} }
@ -255,38 +254,26 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Sampler not selected"),
positiveClip.Output, positiveClip.Output,
negativeClip.Output, negativeClip.Output,
builder.Connections.Latent builder.GetPrimaryAsLatent(),
?? throw new ValidationException("Latent source not set"),
samplerCardViewModel.Steps, samplerCardViewModel.Steps,
totalSteps, totalSteps,
false false
) )
); );
builder.Connections.Latent = sampler.Output;
builder.Connections.Primary = sampler.Output;
} }
public static string SetupOutputImage(this ComfyNodeBuilder builder) public static string SetupOutputImage(this ComfyNodeBuilder builder)
{ {
// Do VAE decoding if not done already
if (builder.Connections.Image is null)
{
var vaeDecoder = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.VAEDecode(
"VAEDecode",
builder.Connections.Latent
?? throw new InvalidOperationException("Latent source not set"),
builder.Connections.GetRefinerOrBaseVAE()
)
);
builder.Connections.Image = vaeDecoder.Output;
builder.Connections.ImageSize = builder.Connections.LatentSize;
}
var previewImage = builder.Nodes.AddNamedNode( var previewImage = builder.Nodes.AddNamedNode(
new NamedComfyNode("SaveImage") new NamedComfyNode("SaveImage")
{ {
ClassType = "PreviewImage", ClassType = "PreviewImage",
Inputs = new Dictionary<string, object?> { ["images"] = builder.Connections.Image } Inputs = new Dictionary<string, object?>
{
["images"] = builder.GetPrimaryAsImage().Data
}
} }
); );

31
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs

@ -17,6 +17,7 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views.Inference; using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
@ -119,47 +120,47 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
?? throw new InvalidOperationException("Source image size is null"); ?? throw new InvalidOperationException("Source image size is null");
// Set source size // Set source size
builder.Connections.ImageSize = sourceImageSize; builder.Connections.PrimarySize = sourceImageSize;
// Load source // Load source
var loadImage = nodes.AddNamedNode( var loadImage = nodes.AddNamedNode(
ComfyNodeBuilder.LoadImage("LoadImage", sourceImageRelativePath) ComfyNodeBuilder.LoadImage("LoadImage", sourceImageRelativePath)
); );
builder.Connections.Image = loadImage.Output1; builder.Connections.Primary = loadImage.Output1;
// If upscale is enabled, add another upscale group // If upscale is enabled, add another upscale group
if (IsUpscaleEnabled) if (IsUpscaleEnabled)
{ {
var upscaleSize = builder.Connections.GetScaledImageSize(UpscalerCardViewModel.Scale); var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
// Build group // Build group
var upscaleGroup = builder.Group_UpscaleToImage( builder.Connections.Primary = builder
.Group_UpscaleToImage(
"Upscale", "Upscale",
builder.Connections.Image!, builder.GetPrimaryAsImage(),
UpscalerCardViewModel.SelectedUpscaler!.Value, UpscalerCardViewModel.SelectedUpscaler!.Value,
upscaleSize.Width, upscaleSize.Width,
upscaleSize.Height upscaleSize.Height
); )
.Output;
// Set as the image output
builder.Connections.Image = upscaleGroup.Output;
} }
// If sharpen is enabled, add another sharpen group // If sharpen is enabled, add another sharpen group
if (IsSharpenEnabled) if (IsSharpenEnabled)
{ {
var sharpenGroup = nodes.AddNamedNode( builder.Connections.Primary = nodes
.AddNamedNode(
ComfyNodeBuilder.ImageSharpen( ComfyNodeBuilder.ImageSharpen(
"Sharpen", "Sharpen",
builder.Connections.Image, builder.GetPrimaryAsImage(),
SharpenCardViewModel.SharpenRadius, SharpenCardViewModel.SharpenRadius,
SharpenCardViewModel.Sigma, SharpenCardViewModel.Sigma,
SharpenCardViewModel.Alpha SharpenCardViewModel.Alpha
) )
); )
.Output;
// Set as the image output
builder.Connections.Image = sharpenGroup.Output;
} }
builder.SetupOutputImage(); builder.SetupOutputImage();

51
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.Drawing;
using System.Linq; using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
@ -8,11 +9,13 @@ using System.Threading.Tasks;
using DynamicData.Binding; using DynamicData.Binding;
using NLog; using NLog;
using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
@ -259,34 +262,25 @@ public class InferenceTextToImageViewModel
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node // If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled) if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
{ {
// Requested upscale to this size // Get new latent size
var hiresSize = builder.Connections.GetScaledLatentSize( var hiresSize = builder.Connections.PrimarySize.WithScale(
HiresUpscalerCardViewModel.Scale HiresUpscalerCardViewModel.Scale
); );
LatentNodeConnection hiresLatent;
// Select between latent upscale and normal upscale based on the upscale method // Select between latent upscale and normal upscale based on the upscale method
var selectedUpscaler = HiresUpscalerCardViewModel.SelectedUpscaler!.Value; var selectedUpscaler = HiresUpscalerCardViewModel.SelectedUpscaler!.Value;
if (selectedUpscaler.Type == ComfyUpscalerType.None) // If upscaler selected, upscale latent image first
{ if (selectedUpscaler.Type != ComfyUpscalerType.None)
// If no upscaler selected or none, just use the latent image
hiresLatent = builder.Connections.Latent!;
}
else
{ {
// Otherwise upscale the latent image builder.Connections.Primary = builder.Group_Upscale(
hiresLatent = builder
.Group_UpscaleToLatent(
"HiresFix", "HiresFix",
builder.Connections.Latent!, builder.Connections.Primary!,
builder.Connections.GetRefinerOrBaseVAE(), builder.Connections.PrimaryVAE!,
selectedUpscaler, selectedUpscaler,
hiresSize.Width, hiresSize.Width,
hiresSize.Height hiresSize.Height
) );
.Output;
} }
// Use refiner model if set, or base if not // Use refiner model if set, or base if not
@ -306,33 +300,34 @@ public class InferenceTextToImageViewModel
?? throw new ValidationException("Scheduler not selected"), ?? throw new ValidationException("Scheduler not selected"),
builder.Connections.GetRefinerOrBaseConditioning(), builder.Connections.GetRefinerOrBaseConditioning(),
builder.Connections.GetRefinerOrBaseNegativeConditioning(), builder.Connections.GetRefinerOrBaseNegativeConditioning(),
hiresLatent, builder.GetPrimaryAsLatent(),
HiresSamplerCardViewModel.DenoiseStrength HiresSamplerCardViewModel.DenoiseStrength
) )
); );
// Set as latest latent // Set as primary
builder.Connections.Latent = hiresSampler.Output; builder.Connections.Primary = hiresSampler.Output;
builder.Connections.LatentSize = hiresSize; builder.Connections.PrimarySize = hiresSize;
} }
// If upscale is enabled, add another upscale group // If upscale is enabled, add another upscale group
if (IsUpscaleEnabled) if (IsUpscaleEnabled)
{ {
var upscaleSize = builder.Connections.GetScaledLatentSize(UpscalerCardViewModel.Scale); var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
// Build group var upscaleResult = builder.Group_Upscale(
var postUpscaleGroup = builder.Group_LatentUpscaleToImage(
"PostUpscale", "PostUpscale",
builder.Connections.Latent!, builder.Connections.Primary!,
builder.Connections.GetRefinerOrBaseVAE(), builder.Connections.PrimaryVAE!,
UpscalerCardViewModel.SelectedUpscaler!.Value, UpscalerCardViewModel.SelectedUpscaler!.Value,
upscaleSize.Width, upscaleSize.Width,
upscaleSize.Height upscaleSize.Height
); );
// Set as the image output builder.Connections.Primary = upscaleResult;
builder.Connections.Image = postUpscaleGroup.Output; builder.Connections.PrimarySize = upscaleSize;
} }
builder.SetupOutputImage(); builder.SetupOutputImage();

11
StabilityMatrix.Core/Extensions/SizeExtensions.cs

@ -0,0 +1,11 @@
using System.Drawing;
namespace StabilityMatrix.Core.Extensions;
public static class SizeExtensions
{
public static Size WithScale(this Size size, double scale)
{
return new Size((int)Math.Floor(size.Width * scale), (int)Math.Floor(size.Height * scale));
}
}

10
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs

@ -0,0 +1,10 @@
using OneOf;
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
/// <summary>
/// Union for the primary Image or Latent node connection
/// </summary>
[GenerateOneOf]
public partial class PrimaryNodeConnection
: OneOfBase<LatentNodeConnection, ImageNodeConnection> { }

212
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -3,7 +3,6 @@ using System.Drawing;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.Tokens;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes; namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
@ -15,10 +14,26 @@ public class ComfyNodeBuilder
{ {
public NodeDictionary Nodes { get; } = new(); public NodeDictionary Nodes { get; } = new();
public Dictionary<Type, NodeConnectionBase> GlobalConnections { get; } = new();
private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8]; private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8];
private string GetUniqueName(string nameBase)
{
var name = $"{nameBase}_1";
for (var i = 0; Nodes.ContainsKey(name); i++)
{
if (i > 1_000_000)
{
throw new InvalidOperationException(
$"Could not find unique name for base {nameBase}"
);
}
name = $"{nameBase}_{i + 1}";
}
return name;
}
public static NamedComfyNode<LatentNodeConnection> VAEEncode( public static NamedComfyNode<LatentNodeConnection> VAEEncode(
string name, string name,
ImageNodeConnection pixels, ImageNodeConnection pixels,
@ -338,7 +353,8 @@ public class ComfyNodeBuilder
VAENodeConnection vae VAENodeConnection vae
) )
{ {
return Nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output; var name = GetUniqueName("VAEDecode");
return Nodes.AddNamedNode(VAEDecode(name, latent, vae)).Output;
} }
public LatentNodeConnection Lambda_ImageToLatent( public LatentNodeConnection Lambda_ImageToLatent(
@ -346,30 +362,8 @@ public class ComfyNodeBuilder
VAENodeConnection vae VAENodeConnection vae
) )
{ {
return Nodes.AddNamedNode(VAEEncode($"{GetRandomPrefix()}_VAEEncode", pixels, vae)).Output; var name = GetUniqueName("VAEEncode");
} return Nodes.AddNamedNode(VAEEncode(name, pixels, vae)).Output;
/// <summary>
/// Get a global connection for a given type
/// </summary>
public TConnection GetConnection<TConnection>()
where TConnection : NodeConnectionBase
{
if (GlobalConnections.TryGetValue(typeof(TConnection), out var connection))
{
return (TConnection)connection;
}
throw new InvalidOperationException($"No global connection of type {typeof(TConnection)}");
}
/// <summary>
/// Set a global connection for a given type
/// </summary>
public void SetConnection<TConnection>(TConnection connection)
where TConnection : NodeConnectionBase
{
GlobalConnections[typeof(TConnection)] = connection;
} }
/// <summary> /// <summary>
@ -392,6 +386,84 @@ public class ComfyNodeBuilder
return upscaler; return upscaler;
} }
/// <summary>
/// Create a group node that scales a given image to image output
/// </summary>
public PrimaryNodeConnection Group_Upscale(
string name,
PrimaryNodeConnection primary,
VAENodeConnection vae,
ComfyUpscaler upscaleInfo,
int width,
int height
)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
return primary.Match<PrimaryNodeConnection>(
latent =>
Nodes
.AddNamedNode(
new NamedComfyNode<LatentNodeConnection>($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary<string, object?>
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
)
.Output,
image =>
Nodes
.AddNamedNode(
ImageScale(
$"{name}_ImageUpscale",
image,
upscaleInfo.Name,
height,
width,
false
)
)
.Output
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space if needed
var samplerImage = GetPrimaryAsImage(primary, vae);
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
return resizedScaled.Output;
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
/// <summary> /// <summary>
/// Create a group node that scales a given image to a given size /// Create a group node that scales a given image to a given size
/// </summary> /// </summary>
@ -640,6 +712,60 @@ public class ComfyNodeBuilder
return currentNode ?? throw new InvalidOperationException("No lora networks given"); return currentNode ?? throw new InvalidOperationException("No lora networks given");
} }
/// <summary>
/// Get or convert latest primary connection to latent
/// </summary>
public LatentNodeConnection GetPrimaryAsLatent()
{
if (Connections.Primary?.IsT0 == true)
{
return Connections.Primary.AsT0;
}
return GetPrimaryAsLatent(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE")
);
}
/// <summary>
/// Get or convert latest primary connection to latent
/// </summary>
public LatentNodeConnection GetPrimaryAsLatent(
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{
return primary.Match(latent => latent, image => Lambda_ImageToLatent(image, vae));
}
/// <summary>
/// Get or convert latest primary connection to image
/// </summary>
public ImageNodeConnection GetPrimaryAsImage()
{
if (Connections.Primary?.IsT1 == true)
{
return Connections.Primary.AsT1;
}
return GetPrimaryAsImage(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE")
);
}
/// <summary>
/// Get or convert latest primary connection to image
/// </summary>
public ImageNodeConnection GetPrimaryAsImage(
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{
return primary.Match(latent => Lambda_LatentToImage(latent, vae), image => image);
}
/// <summary> /// <summary>
/// Convert to a NodeDictionary /// Convert to a NodeDictionary
/// </summary> /// </summary>
@ -666,38 +792,20 @@ public class ComfyNodeBuilder
public ConditioningNodeConnection? RefinerConditioning { get; set; } public ConditioningNodeConnection? RefinerConditioning { get; set; }
public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; } public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; }
public LatentNodeConnection? Latent { get; set; } public PrimaryNodeConnection? Primary { get; set; }
public VAENodeConnection? PrimaryVAE { get; set; }
public Size PrimarySize { get; set; }
/*public LatentNodeConnection? Latent { get; set; }
public Size LatentSize { get; set; } public Size LatentSize { get; set; }
public ImageNodeConnection? Image { get; set; } public ImageNodeConnection? Image { get; set; }
public Size ImageSize { get; set; } public Size ImageSize { get; set; }*/
public List<NamedComfyNode> OutputNodes { get; } = new(); public List<NamedComfyNode> OutputNodes { get; } = new();
public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name); public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name);
/// <summary>
/// Gets the latent size scaled by a given factor
/// </summary>
public Size GetScaledLatentSize(double scale)
{
return new Size(
(int)Math.Floor(LatentSize.Width * scale),
(int)Math.Floor(LatentSize.Height * scale)
);
}
/// <summary>
/// Gets the image size scaled by a given factor
/// </summary>
public Size GetScaledImageSize(double scale)
{
return new Size(
(int)Math.Floor(ImageSize.Width * scale),
(int)Math.Floor(ImageSize.Height * scale)
);
}
public VAENodeConnection GetRefinerOrBaseVAE() public VAENodeConnection GetRefinerOrBaseVAE()
{ {
return RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE"); return RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE");

Loading…
Cancel
Save