Browse Source

Refactor for union latent/image node building

pull/333/head
Ionite 1 year ago
parent
commit
840f664c34
No known key found for this signature in database
  1. 43
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 31
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
  3. 51
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  4. 11
      StabilityMatrix.Core/Extensions/SizeExtensions.cs
  5. 10
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs
  6. 212
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

43
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -26,8 +26,8 @@ public static class ComfyNodeBuilderExtensions
)
);
builder.Connections.Latent = emptyLatent.Output;
builder.Connections.LatentSize = new Size(
builder.Connections.Primary = emptyLatent.Output;
builder.Connections.PrimarySize = new Size(
samplerCardViewModel.Width,
samplerCardViewModel.Height
);
@ -35,11 +35,11 @@ public static class ComfyNodeBuilderExtensions
// If batch index is selected, add a LatentFromBatch
if (batchSizeCardViewModel.IsBatchIndexEnabled)
{
builder.Connections.Latent = builder.Nodes
builder.Connections.Primary = builder.Nodes
.AddNamedNode(
ComfyNodeBuilder.LatentFromBatch(
"LatentFromBatch",
builder.Connections.Latent,
builder.GetPrimaryAsLatent(),
// remote expects a 0-based index, vm is 1-based
batchSizeCardViewModel.BatchIndex - 1,
1
@ -133,12 +133,12 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"),
positiveClip.Output,
negativeClip.Output,
builder.Connections.Latent
builder.GetPrimaryAsLatent()
?? throw new ValidationException("Latent source not set"),
samplerCardViewModel.DenoiseStrength
)
);
builder.Connections.Latent = sampler.Output;
builder.Connections.Primary = sampler.Output;
}
// Add base sampler (with refiner)
else
@ -160,14 +160,13 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"),
positiveClip.Output,
negativeClip.Output,
builder.Connections.Latent
?? throw new ValidationException("Latent source not set"),
builder.GetPrimaryAsLatent(),
0,
samplerCardViewModel.Steps,
true
)
);
builder.Connections.Latent = sampler.Output;
builder.Connections.Primary = sampler.Output;
}
}
@ -255,38 +254,26 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"),
positiveClip.Output,
negativeClip.Output,
builder.Connections.Latent
?? throw new ValidationException("Latent source not set"),
builder.GetPrimaryAsLatent(),
samplerCardViewModel.Steps,
totalSteps,
false
)
);
builder.Connections.Latent = sampler.Output;
builder.Connections.Primary = sampler.Output;
}
public static string SetupOutputImage(this ComfyNodeBuilder builder)
{
// Do VAE decoding if not done already
if (builder.Connections.Image is null)
{
var vaeDecoder = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.VAEDecode(
"VAEDecode",
builder.Connections.Latent
?? throw new InvalidOperationException("Latent source not set"),
builder.Connections.GetRefinerOrBaseVAE()
)
);
builder.Connections.Image = vaeDecoder.Output;
builder.Connections.ImageSize = builder.Connections.LatentSize;
}
var previewImage = builder.Nodes.AddNamedNode(
new NamedComfyNode("SaveImage")
{
ClassType = "PreviewImage",
Inputs = new Dictionary<string, object?> { ["images"] = builder.Connections.Image }
Inputs = new Dictionary<string, object?>
{
["images"] = builder.GetPrimaryAsImage().Data
}
}
);

31
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs

@ -17,6 +17,7 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
@ -119,47 +120,47 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
?? throw new InvalidOperationException("Source image size is null");
// Set source size
builder.Connections.ImageSize = sourceImageSize;
builder.Connections.PrimarySize = sourceImageSize;
// Load source
var loadImage = nodes.AddNamedNode(
ComfyNodeBuilder.LoadImage("LoadImage", sourceImageRelativePath)
);
builder.Connections.Image = loadImage.Output1;
builder.Connections.Primary = loadImage.Output1;
// If upscale is enabled, add another upscale group
if (IsUpscaleEnabled)
{
var upscaleSize = builder.Connections.GetScaledImageSize(UpscalerCardViewModel.Scale);
var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
// Build group
var upscaleGroup = builder.Group_UpscaleToImage(
builder.Connections.Primary = builder
.Group_UpscaleToImage(
"Upscale",
builder.Connections.Image!,
builder.GetPrimaryAsImage(),
UpscalerCardViewModel.SelectedUpscaler!.Value,
upscaleSize.Width,
upscaleSize.Height
);
// Set as the image output
builder.Connections.Image = upscaleGroup.Output;
)
.Output;
}
// If sharpen is enabled, add another sharpen group
if (IsSharpenEnabled)
{
var sharpenGroup = nodes.AddNamedNode(
builder.Connections.Primary = nodes
.AddNamedNode(
ComfyNodeBuilder.ImageSharpen(
"Sharpen",
builder.Connections.Image,
builder.GetPrimaryAsImage(),
SharpenCardViewModel.SharpenRadius,
SharpenCardViewModel.Sigma,
SharpenCardViewModel.Alpha
)
);
// Set as the image output
builder.Connections.Image = sharpenGroup.Output;
)
.Output;
}
builder.SetupOutputImage();

51
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Drawing;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
@ -8,11 +9,13 @@ using System.Threading.Tasks;
using DynamicData.Binding;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
@ -259,34 +262,25 @@ public class InferenceTextToImageViewModel
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
{
// Requested upscale to this size
var hiresSize = builder.Connections.GetScaledLatentSize(
// Get new latent size
var hiresSize = builder.Connections.PrimarySize.WithScale(
HiresUpscalerCardViewModel.Scale
);
LatentNodeConnection hiresLatent;
// Select between latent upscale and normal upscale based on the upscale method
var selectedUpscaler = HiresUpscalerCardViewModel.SelectedUpscaler!.Value;
if (selectedUpscaler.Type == ComfyUpscalerType.None)
{
// If no upscaler selected or none, just use the latent image
hiresLatent = builder.Connections.Latent!;
}
else
// If upscaler selected, upscale latent image first
if (selectedUpscaler.Type != ComfyUpscalerType.None)
{
// Otherwise upscale the latent image
hiresLatent = builder
.Group_UpscaleToLatent(
builder.Connections.Primary = builder.Group_Upscale(
"HiresFix",
builder.Connections.Latent!,
builder.Connections.GetRefinerOrBaseVAE(),
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
selectedUpscaler,
hiresSize.Width,
hiresSize.Height
)
.Output;
);
}
// Use refiner model if set, or base if not
@ -306,33 +300,34 @@ public class InferenceTextToImageViewModel
?? throw new ValidationException("Scheduler not selected"),
builder.Connections.GetRefinerOrBaseConditioning(),
builder.Connections.GetRefinerOrBaseNegativeConditioning(),
hiresLatent,
builder.GetPrimaryAsLatent(),
HiresSamplerCardViewModel.DenoiseStrength
)
);
// Set as latest latent
builder.Connections.Latent = hiresSampler.Output;
builder.Connections.LatentSize = hiresSize;
// Set as primary
builder.Connections.Primary = hiresSampler.Output;
builder.Connections.PrimarySize = hiresSize;
}
// If upscale is enabled, add another upscale group
if (IsUpscaleEnabled)
{
var upscaleSize = builder.Connections.GetScaledLatentSize(UpscalerCardViewModel.Scale);
var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
// Build group
var postUpscaleGroup = builder.Group_LatentUpscaleToImage(
var upscaleResult = builder.Group_Upscale(
"PostUpscale",
builder.Connections.Latent!,
builder.Connections.GetRefinerOrBaseVAE(),
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
UpscalerCardViewModel.SelectedUpscaler!.Value,
upscaleSize.Width,
upscaleSize.Height
);
// Set as the image output
builder.Connections.Image = postUpscaleGroup.Output;
builder.Connections.Primary = upscaleResult;
builder.Connections.PrimarySize = upscaleSize;
}
builder.SetupOutputImage();

11
StabilityMatrix.Core/Extensions/SizeExtensions.cs

@ -0,0 +1,11 @@
using System.Drawing;
namespace StabilityMatrix.Core.Extensions;
public static class SizeExtensions
{
public static Size WithScale(this Size size, double scale)
{
return new Size((int)Math.Floor(size.Width * scale), (int)Math.Floor(size.Height * scale));
}
}

10
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/PrimaryNodeConnection.cs

@ -0,0 +1,10 @@
using OneOf;
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
/// <summary>
/// Union for the primary Image or Latent node connection
/// </summary>
[GenerateOneOf]
public partial class PrimaryNodeConnection
: OneOfBase<LatentNodeConnection, ImageNodeConnection> { }

212
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -3,7 +3,6 @@ using System.Drawing;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.Tokens;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
@ -15,10 +14,26 @@ public class ComfyNodeBuilder
{
public NodeDictionary Nodes { get; } = new();
public Dictionary<Type, NodeConnectionBase> GlobalConnections { get; } = new();
private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8];
private string GetUniqueName(string nameBase)
{
var name = $"{nameBase}_1";
for (var i = 0; Nodes.ContainsKey(name); i++)
{
if (i > 1_000_000)
{
throw new InvalidOperationException(
$"Could not find unique name for base {nameBase}"
);
}
name = $"{nameBase}_{i + 1}";
}
return name;
}
public static NamedComfyNode<LatentNodeConnection> VAEEncode(
string name,
ImageNodeConnection pixels,
@ -338,7 +353,8 @@ public class ComfyNodeBuilder
VAENodeConnection vae
)
{
return Nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output;
var name = GetUniqueName("VAEDecode");
return Nodes.AddNamedNode(VAEDecode(name, latent, vae)).Output;
}
public LatentNodeConnection Lambda_ImageToLatent(
@ -346,30 +362,8 @@ public class ComfyNodeBuilder
VAENodeConnection vae
)
{
return Nodes.AddNamedNode(VAEEncode($"{GetRandomPrefix()}_VAEEncode", pixels, vae)).Output;
}
/// <summary>
/// Get a global connection for a given type
/// </summary>
public TConnection GetConnection<TConnection>()
where TConnection : NodeConnectionBase
{
if (GlobalConnections.TryGetValue(typeof(TConnection), out var connection))
{
return (TConnection)connection;
}
throw new InvalidOperationException($"No global connection of type {typeof(TConnection)}");
}
/// <summary>
/// Set a global connection for a given type
/// </summary>
public void SetConnection<TConnection>(TConnection connection)
where TConnection : NodeConnectionBase
{
GlobalConnections[typeof(TConnection)] = connection;
var name = GetUniqueName("VAEEncode");
return Nodes.AddNamedNode(VAEEncode(name, pixels, vae)).Output;
}
/// <summary>
@ -392,6 +386,84 @@ public class ComfyNodeBuilder
return upscaler;
}
/// <summary>
/// Create a group node that scales a given image to image output
/// </summary>
public PrimaryNodeConnection Group_Upscale(
string name,
PrimaryNodeConnection primary,
VAENodeConnection vae,
ComfyUpscaler upscaleInfo,
int width,
int height
)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
return primary.Match<PrimaryNodeConnection>(
latent =>
Nodes
.AddNamedNode(
new NamedComfyNode<LatentNodeConnection>($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary<string, object?>
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
)
.Output,
image =>
Nodes
.AddNamedNode(
ImageScale(
$"{name}_ImageUpscale",
image,
upscaleInfo.Name,
height,
width,
false
)
)
.Output
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space if needed
var samplerImage = GetPrimaryAsImage(primary, vae);
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
return resizedScaled.Output;
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
/// <summary>
/// Create a group node that scales a given image to a given size
/// </summary>
@ -640,6 +712,60 @@ public class ComfyNodeBuilder
return currentNode ?? throw new InvalidOperationException("No lora networks given");
}
/// <summary>
/// Get or convert latest primary connection to latent
/// </summary>
public LatentNodeConnection GetPrimaryAsLatent()
{
if (Connections.Primary?.IsT0 == true)
{
return Connections.Primary.AsT0;
}
return GetPrimaryAsLatent(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE")
);
}
/// <summary>
/// Get or convert latest primary connection to latent
/// </summary>
public LatentNodeConnection GetPrimaryAsLatent(
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{
return primary.Match(latent => latent, image => Lambda_ImageToLatent(image, vae));
}
/// <summary>
/// Get or convert latest primary connection to image
/// </summary>
public ImageNodeConnection GetPrimaryAsImage()
{
if (Connections.Primary?.IsT1 == true)
{
return Connections.Primary.AsT1;
}
return GetPrimaryAsImage(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE")
);
}
/// <summary>
/// Get or convert latest primary connection to image
/// </summary>
public ImageNodeConnection GetPrimaryAsImage(
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{
return primary.Match(latent => Lambda_LatentToImage(latent, vae), image => image);
}
/// <summary>
/// Convert to a NodeDictionary
/// </summary>
@ -666,38 +792,20 @@ public class ComfyNodeBuilder
public ConditioningNodeConnection? RefinerConditioning { get; set; }
public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; }
public LatentNodeConnection? Latent { get; set; }
public PrimaryNodeConnection? Primary { get; set; }
public VAENodeConnection? PrimaryVAE { get; set; }
public Size PrimarySize { get; set; }
/*public LatentNodeConnection? Latent { get; set; }
public Size LatentSize { get; set; }
public ImageNodeConnection? Image { get; set; }
public Size ImageSize { get; set; }
public Size ImageSize { get; set; }*/
public List<NamedComfyNode> OutputNodes { get; } = new();
public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name);
/// <summary>
/// Gets the latent size scaled by a given factor
/// </summary>
public Size GetScaledLatentSize(double scale)
{
return new Size(
(int)Math.Floor(LatentSize.Width * scale),
(int)Math.Floor(LatentSize.Height * scale)
);
}
/// <summary>
/// Gets the image size scaled by a given factor
/// </summary>
public Size GetScaledImageSize(double scale)
{
return new Size(
(int)Math.Floor(ImageSize.Width * scale),
(int)Math.Floor(ImageSize.Height * scale)
);
}
public VAENodeConnection GetRefinerOrBaseVAE()
{
return RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE");

Loading…
Cancel
Save