using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using System.Runtime.Serialization;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
/// <summary>
/// Builder functions for ComfyUI nodes.
/// </summary>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public class ComfyNodeBuilder
{
/// <summary>
/// Nodes added so far, keyed by node name.
/// </summary>
public NodeDictionary Nodes { get; } = new();

/// <summary>
/// Returns the first 8 characters of a freshly generated GUID's string form.
/// </summary>
private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8];

/// <summary>
/// Returns the first name of the form <c>{nameBase}_{n}</c> (n = 1, 2, 3, ...) that is not
/// already a key in <see cref="Nodes"/>.
/// </summary>
/// <param name="nameBase">Prefix of the generated name, e.g. a node class name.</param>
/// <returns>An unused node name.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when every suffix up to the search limit is already taken.
/// </exception>
private string GetUniqueName(string nameBase)
{
    // Highest suffix probed before giving up (same limit as before).
    const int maxSuffix = 1_000_001;

    // Each candidate is probed exactly once, in ascending order.
    for (var suffix = 1; suffix <= maxSuffix; suffix++)
    {
        var candidate = $"{nameBase}_{suffix}";
        if (!Nodes.ContainsKey(candidate))
        {
            return candidate;
        }
    }

    throw new InvalidOperationException($"Could not find unique name for base {nameBase}");
}
/// <summary>
/// Typed <c>VAEEncode</c> node: encodes <see cref="Pixels"/> using <see cref="Vae"/>.
/// </summary>
/// <remarks>
/// Lambda_ImageToLatent adds this node and exposes its output as a latent connection.
/// NOTE(review): conversion of these properties into node inputs happens in ComfyTypedNodeBase
/// (not visible here); presumably property names map to snake_case input keys — confirm.
/// </remarks>
public record VAEEncode : ComfyTypedNodeBase
{
/// <summary>Image connection to encode.</summary>
public required ImageNodeConnection Pixels { get; init; }
/// <summary>VAE connection used for encoding.</summary>
public required VAENodeConnection Vae { get; init; }
}
/// <summary>
/// Typed <c>VAEDecode</c> node: decodes <see cref="Samples"/> using <see cref="Vae"/>.
/// </summary>
/// <remarks>
/// Lambda_LatentToImage adds this node and exposes its output as an image connection.
/// </remarks>
public record VAEDecode : ComfyTypedNodeBase
{
/// <summary>Latent connection to decode.</summary>
public required LatentNodeConnection Samples { get; init; }
/// <summary>VAE connection used for decoding.</summary>
public required VAENodeConnection Vae { get; init; }
}
/// <summary>
/// Typed <c>KSampler</c> node inputs. Every property is <c>required</c>.
/// </summary>
/// <remarks>
/// The commented-out legacy factory that follows built the same node by hand, using input keys
/// model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image and denoise.
/// </remarks>
public record KSampler : ComfyTypedNodeBase
{
public required ModelNodeConnection Model { get; init; }
public required ulong Seed { get; init; }
public required int Steps { get; init; }
public required double Cfg { get; init; }
/// <summary>Sampler name as sent to ComfyUI (the legacy factory used ComfySampler.Name).</summary>
public required string SamplerName { get; init; }
/// <summary>Scheduler name as sent to ComfyUI (the legacy factory used ComfyScheduler.Name).</summary>
public required string Scheduler { get; init; }
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required LatentNodeConnection LatentImage { get; init; }
public required double Denoise { get; init; }
}
/*public static NamedComfyNode KSampler(
string name,
ModelNodeConnection model,
ulong seed,
int steps,
double cfg,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
double denoise
)
{
return new NamedComfyNode(name)
{
ClassType = "KSampler",
Inputs = new Dictionary
{
["model"] = model.Data,
["seed"] = seed,
["steps"] = steps,
["cfg"] = cfg,
["sampler_name"] = sampler.Name,
["scheduler"] = scheduler.Name,
["positive"] = positive.Data,
["negative"] = negative.Data,
["latent_image"] = latentImage.Data,
["denoise"] = denoise
}
};
}*/
/// <summary>
/// Typed <c>KSamplerAdvanced</c> node inputs.
/// </summary>
/// <remarks>
/// NOTE(review): serialization is handled by ComfyTypedNodeBase (not visible here); property
/// declaration order and the BoolStringMember attributes may affect the emitted inputs.
/// </remarks>
public record KSamplerAdvanced : ComfyTypedNodeBase
{
public required ModelNodeConnection Model { get; init; }
/// <summary>
/// Whether noise is added. The attribute presumably serializes true/false as "enable"/"disable",
/// as the commented-out legacy factory that follows did by hand — confirm in BoolStringMemberAttribute.
/// </summary>
[BoolStringMember("enable", "disable")]
public required bool AddNoise { get; init; }
public required ulong NoiseSeed { get; init; }
public required int Steps { get; init; }
public required double Cfg { get; init; }
public required string SamplerName { get; init; }
public required string Scheduler { get; init; }
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required LatentNodeConnection LatentImage { get; init; }
public required int StartAtStep { get; init; }
public required int EndAtStep { get; init; }
/// <summary>
/// Not <c>required</c>, so it defaults to false. Serialized via the same enable/disable mapping
/// as <see cref="AddNoise"/>.
/// </summary>
[BoolStringMember("enable", "disable")]
public bool ReturnWithLeftoverNoise { get; init; }
}
/*public static NamedComfyNode KSamplerAdvanced(
string name,
ModelNodeConnection model,
bool addNoise,
ulong noiseSeed,
int steps,
double cfg,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
int startAtStep,
int endAtStep,
bool returnWithLeftoverNoise
)
{
return new NamedComfyNode(name)
{
ClassType = "KSamplerAdvanced",
Inputs = new Dictionary
{
["model"] = model.Data,
["add_noise"] = addNoise ? "enable" : "disable",
["noise_seed"] = noiseSeed,
["steps"] = steps,
["cfg"] = cfg,
["sampler_name"] = sampler.Name,
["scheduler"] = scheduler.Name,
["positive"] = positive.Data,
["negative"] = negative.Data,
["latent_image"] = latentImage.Data,
["start_at_step"] = startAtStep,
["end_at_step"] = endAtStep,
["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable"
}
};
}*/
/// <summary>
/// Typed <c>EmptyLatentImage</c> node inputs: batch size and dimensions of the latent to create.
/// </summary>
/// <remarks>
/// NOTE(review): Width/Height are presumably in pixels — confirm against ComfyUI's node definition.
/// </remarks>
public record EmptyLatentImage : ComfyTypedNodeBase
{
public required int BatchSize { get; init; }
public required int Height { get; init; }
public required int Width { get; init; }
}
/// <summary>
/// Creates (without adding it to <see cref="Nodes"/>) a <c>LatentFromBatch</c> node named
/// <paramref name="name"/> with inputs <c>samples</c>, <c>batch_index</c> and <c>length</c>.
/// </summary>
/// <param name="name">Name of the new node.</param>
/// <param name="samples">Latent batch connection to select from.</param>
/// <param name="batchIndex">Value for the <c>batch_index</c> input.</param>
/// <param name="length">Value for the <c>length</c> input.</param>
public static NamedComfyNode LatentFromBatch(
    string name,
    LatentNodeConnection samples,
    int batchIndex,
    int length
)
{
    var inputs = new Dictionary
    {
        { "samples", samples.Data },
        { "batch_index", batchIndex },
        { "length", length }
    };

    return new NamedComfyNode(name) { ClassType = "LatentFromBatch", Inputs = inputs };
}
/// <summary>
/// Creates (without adding it to <see cref="Nodes"/>) an <c>ImageUpscaleWithModel</c> node that
/// connects <paramref name="upscaleModel"/> and <paramref name="image"/>.
/// </summary>
/// <param name="name">Name of the new node.</param>
/// <param name="upscaleModel">Connection to a loaded upscale model.</param>
/// <param name="image">Image connection to upscale.</param>
public static NamedComfyNode ImageUpscaleWithModel(
    string name,
    UpscaleModelNodeConnection upscaleModel,
    ImageNodeConnection image
)
{
    var upscaleNode = new NamedComfyNode(name)
    {
        ClassType = "ImageUpscaleWithModel",
        Inputs = new Dictionary
        {
            { "upscale_model", upscaleModel.Data },
            { "image", image.Data }
        }
    };

    return upscaleNode;
}
/// <summary>
/// Creates (without adding it to <see cref="Nodes"/>) an <c>UpscaleModelLoader</c> node whose
/// <c>model_name</c> input is <paramref name="modelName"/>.
/// </summary>
/// <param name="name">Name of the new node.</param>
/// <param name="modelName">Upscale model file name passed to ComfyUI.</param>
public static NamedComfyNode UpscaleModelLoader(string name, string modelName)
{
    var inputs = new Dictionary { { "model_name", modelName } };
    return new NamedComfyNode(name) { ClassType = "UpscaleModelLoader", Inputs = inputs };
}
/// <summary>
/// Creates (without adding it to <see cref="Nodes"/>) an <c>ImageScale</c> node that resizes
/// <paramref name="image"/> to <paramref name="width"/> x <paramref name="height"/>.
/// </summary>
/// <param name="name">Name of the new node.</param>
/// <param name="image">Image connection to resize.</param>
/// <param name="method">Value for the <c>upscale_method</c> input, e.g. "bilinear".</param>
/// <param name="height">Target height.</param>
/// <param name="width">Target width.</param>
/// <param name="crop">True sends crop mode "center"; false sends "disabled".</param>
public static NamedComfyNode ImageScale(
    string name,
    ImageNodeConnection image,
    string method,
    int height,
    int width,
    bool crop
)
{
    string cropMode;
    if (crop)
    {
        cropMode = "center";
    }
    else
    {
        cropMode = "disabled";
    }

    var inputs = new Dictionary
    {
        { "image", image.Data },
        { "upscale_method", method },
        { "height", height },
        { "width", width },
        { "crop", cropMode }
    };

    return new NamedComfyNode(name) { ClassType = "ImageScale", Inputs = inputs };
}
/// <summary>
/// Typed <c>VAELoader</c> node inputs.
/// </summary>
public record VAELoader : ComfyTypedNodeBase
{
/// <summary>
/// Name of the VAE to load. NOTE(review): presumably relative to ComfyUI's VAE folder — confirm.
/// </summary>
public required string VaeName { get; init; }
}
/// <summary>
/// Creates (without adding it to <see cref="Nodes"/>) a <c>LoraLoader</c> node that applies
/// <paramref name="loraName"/> to <paramref name="model"/> and <paramref name="clip"/>.
/// </summary>
/// <param name="name">Name of the new node.</param>
/// <param name="model">Model connection the LoRA is applied to.</param>
/// <param name="clip">CLIP connection the LoRA is applied to.</param>
/// <param name="loraName">LoRA file name passed to ComfyUI.</param>
/// <param name="strengthModel">Value for the <c>strength_model</c> input.</param>
/// <param name="strengthClip">Value for the <c>strength_clip</c> input.</param>
public static NamedComfyNode LoraLoader(
    string name,
    ModelNodeConnection model,
    ClipNodeConnection clip,
    string loraName,
    double strengthModel,
    double strengthClip
)
{
    var inputs = new Dictionary
    {
        { "model", model.Data },
        { "clip", clip.Data },
        { "lora_name", loraName },
        { "strength_model", strengthModel },
        { "strength_clip", strengthClip }
    };

    return new NamedComfyNode(name) { ClassType = "LoraLoader", Inputs = inputs };
}
/// <summary>
/// Typed <c>CheckpointLoaderSimple</c> node inputs.
/// </summary>
/// <remarks>
/// NOTE(review): input serialization is handled by ComfyTypedNodeBase, which is not visible here.
/// </remarks>
public record CheckpointLoaderSimple
: ComfyTypedNodeBase
{
/// <summary>Checkpoint file name to load.</summary>
public required string CkptName { get; init; }
}
/// <summary>
/// Typed <c>FreeU</c> node inputs: the model connection plus four numeric parameters.
/// </summary>
/// <remarks>
/// NOTE(review): B1/B2/S1/S2 are passed through unchanged; their meaning and valid ranges are
/// defined by ComfyUI's FreeU node — confirm there.
/// </remarks>
public record FreeU : ComfyTypedNodeBase
{
public required ModelNodeConnection Model { get; init; }
public required double B1 { get; init; }
public required double B2 { get; init; }
public required double S1 { get; init; }
public required double S2 { get; init; }
}
/// <summary>
/// Typed <c>CLIPTextEncode</c> node: encodes <see cref="Text"/> with <see cref="Clip"/>.
/// </summary>
/// <remarks>
/// The upper-case name presumably has to match the ComfyUI class name, hence the naming suppression.
/// </remarks>
[SuppressMessage("ReSharper", "InconsistentNaming")]
public record CLIPTextEncode : ComfyTypedNodeBase
{
public required ClipNodeConnection Clip { get; init; }
public required string Text { get; init; }
}
/// <summary>
/// Creates (without adding it to <see cref="Nodes"/>) a <c>CLIPTextEncode</c> node that encodes
/// <paramref name="text"/> with <paramref name="clip"/>.
/// </summary>
/// <param name="name">Name of the new node.</param>
/// <param name="clip">CLIP connection used for encoding.</param>
/// <param name="text">Prompt text.</param>
public static NamedComfyNode ClipTextEncode(
    string name,
    ClipNodeConnection clip,
    string text
)
{
    var encodeNode = new NamedComfyNode(name)
    {
        ClassType = "CLIPTextEncode",
        Inputs = new Dictionary
        {
            { "clip", clip.Data },
            { "text", text }
        }
    };

    return encodeNode;
}
/// <summary>
/// Typed <c>LoadImage</c> node inputs.
/// </summary>
public record LoadImage : ComfyTypedNodeBase
{
/// <summary>
/// Path relative to the Comfy input directory.
/// </summary>
public required string Image { get; init; }
}
/// <summary>
/// Typed <c>PreviewImage</c> node inputs: the image connection to preview.
/// </summary>
public record PreviewImage : ComfyTypedNodeBase
{
public required ImageNodeConnection Images { get; init; }
}
/// <summary>
/// Typed <c>ImageSharpen</c> node inputs.
/// </summary>
/// <remarks>
/// NOTE(review): SharpenRadius/Sigma/Alpha are passed through unchanged; valid ranges are defined
/// by ComfyUI's ImageSharpen node — confirm there.
/// </remarks>
public record ImageSharpen : ComfyTypedNodeBase
{
public required ImageNodeConnection Image { get; init; }
public required int SharpenRadius { get; init; }
public required double Sigma { get; init; }
public required double Alpha { get; init; }
}
/// <summary>
/// Typed <c>ControlNetLoader</c> node inputs.
/// </summary>
public record ControlNetLoader : ComfyTypedNodeBase
{
/// <summary>ControlNet model file name to load.</summary>
public required string ControlNetName { get; init; }
}
/// <summary>
/// Typed <c>ControlNetApplyAdvanced</c> node inputs: applies <see cref="ControlNet"/> with
/// <see cref="Image"/> to both the positive and negative conditioning.
/// </summary>
/// <remarks>
/// NOTE(review): StartPercent/EndPercent are presumably fractions of sampling progress in [0, 1] — confirm.
/// </remarks>
public record ControlNetApplyAdvanced
: ComfyTypedNodeBase
{
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required ControlNetNodeConnection ControlNet { get; init; }
public required ImageNodeConnection Image { get; init; }
public required double Strength { get; init; }
public required double StartPercent { get; init; }
public required double EndPercent { get; init; }
}
/// <summary>
/// Adds a <see cref="VAEDecode"/> node with a unique <c>VAEDecode_{n}</c> name that decodes
/// <paramref name="latent"/> using <paramref name="vae"/>.
/// </summary>
/// <param name="latent">Latent connection to decode.</param>
/// <param name="vae">VAE connection used for decoding.</param>
/// <returns>The output connection of the added node.</returns>
public ImageNodeConnection Lambda_LatentToImage(
    LatentNodeConnection latent,
    VAENodeConnection vae
)
{
    var decoder = new VAEDecode
    {
        Name = GetUniqueName("VAEDecode"),
        Samples = latent,
        Vae = vae
    };

    var addedDecoder = Nodes.AddTypedNode(decoder);
    return addedDecoder.Output;
}
/// <summary>
/// Adds a <see cref="VAEEncode"/> node with a unique <c>VAEEncode_{n}</c> name that encodes
/// <paramref name="pixels"/> using <paramref name="vae"/>.
/// </summary>
/// <param name="pixels">Image connection to encode.</param>
/// <param name="vae">VAE connection used for encoding.</param>
/// <returns>The output connection of the added node.</returns>
public LatentNodeConnection Lambda_ImageToLatent(
    ImageNodeConnection pixels,
    VAENodeConnection vae
) =>
    Nodes
        .AddTypedNode(
            new VAEEncode
            {
                Name = GetUniqueName("VAEEncode"),
                Pixels = pixels,
                Vae = vae
            }
        )
        .Output;
/// <summary>
/// Adds an <c>UpscaleModelLoader</c> node for <paramref name="modelName"/> and an
/// <c>ImageUpscaleWithModel</c> node that applies the loaded model to <paramref name="image"/>.
/// </summary>
/// <param name="name">Prefix for the names of both added nodes.</param>
/// <param name="modelName">Upscale model file name.</param>
/// <param name="image">Image connection to upscale.</param>
/// <returns>The added <c>ImageUpscaleWithModel</c> node.</returns>
public NamedComfyNode Group_UpscaleWithModel(
    string name,
    string modelName,
    ImageNodeConnection image
)
{
    var loaderNode = UpscaleModelLoader($"{name}_UpscaleModelLoader", modelName);
    var addedLoader = Nodes.AddNamedNode(loaderNode);

    var upscaleNode = ImageUpscaleWithModel(
        $"{name}_ImageUpscaleWithModel",
        addedLoader.Output,
        image
    );
    return Nodes.AddNamedNode(upscaleNode);
}
/// <summary>
/// Adds nodes that resize <paramref name="primary"/> to <paramref name="width"/> x
/// <paramref name="height"/> and returns the resulting connection.
/// </summary>
/// <remarks>
/// For <see cref="ComfyUpscalerType.Latent"/> upscalers the input stays in its current space: a latent
/// is resized by a <c>LatentUpscale</c> node and an image by an <c>ImageScale</c> node, both with crop
/// disabled and the upscaler's name as the method.
/// For <see cref="ComfyUpscalerType.ESRGAN"/> upscalers the input is turned into an image (decoded with
/// <paramref name="vae"/> if it is a latent), run through the upscale model, then resized bilinearly
/// to the requested size; the result is an image connection.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown for any other upscaler type.</exception>
public PrimaryNodeConnection Group_Upscale(
    string name,
    PrimaryNodeConnection primary,
    VAENodeConnection vae,
    ComfyUpscaler upscaleInfo,
    int width,
    int height
)
{
    if (upscaleInfo.Type == ComfyUpscalerType.Latent)
    {
        PrimaryNodeConnection ResizeLatent(LatentNodeConnection latent) =>
            Nodes
                .AddNamedNode(
                    new NamedComfyNode($"{name}_LatentUpscale")
                    {
                        ClassType = "LatentUpscale",
                        Inputs = new Dictionary
                        {
                            ["upscale_method"] = upscaleInfo.Name,
                            ["width"] = width,
                            ["height"] = height,
                            ["crop"] = "disabled",
                            ["samples"] = latent.Data,
                        }
                    }
                )
                .Output;

        PrimaryNodeConnection ResizeImage(ImageNodeConnection image) =>
            Nodes
                .AddNamedNode(
                    ImageScale($"{name}_ImageUpscale", image, upscaleInfo.Name, height, width, false)
                )
                .Output;

        return primary.Match(ResizeLatent, ResizeImage);
    }

    if (upscaleInfo.Type != ComfyUpscalerType.ESRGAN)
    {
        throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
    }

    // The upscale model consumes an image, so decode first if the primary is still a latent.
    var sourceImage = GetPrimaryAsImage(primary, vae);
    var modelUpscaled = Group_UpscaleWithModel(
        $"{name}_ModelUpscale",
        upscaleInfo.Name,
        sourceImage
    );

    // The model's scale factor is fixed (e.g. 2x/4x), so resize again to hit the requested size.
    var resized = Nodes.AddNamedNode(
        ImageScale($"{name}_ImageScale", modelUpscaled.Output, "bilinear", height, width, false)
    );
    return resized.Output;
}
/// <summary>
/// Adds nodes that resize <paramref name="latent"/> to <paramref name="width"/> x
/// <paramref name="height"/> and returns the last added node, intended to yield a latent.
/// </summary>
/// <remarks>
/// <see cref="ComfyUpscalerType.Latent"/>: one <c>LatentUpscale</c> node (crop disabled) using the
/// upscaler's name as the method.
/// <see cref="ComfyUpscalerType.ESRGAN"/>: decode with <paramref name="vae"/>, upscale with the model,
/// resize bilinearly to the requested size, then encode back with <paramref name="vae"/>.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown for any other upscaler type.</exception>
public NamedComfyNode Group_UpscaleToLatent(
    string name,
    LatentNodeConnection latent,
    VAENodeConnection vae,
    ComfyUpscaler upscaleInfo,
    int width,
    int height
)
{
    if (upscaleInfo.Type == ComfyUpscalerType.Latent)
    {
        var latentUpscale = new NamedComfyNode($"{name}_LatentUpscale")
        {
            ClassType = "LatentUpscale",
            Inputs = new Dictionary
            {
                ["upscale_method"] = upscaleInfo.Name,
                ["width"] = width,
                ["height"] = height,
                ["crop"] = "disabled",
                ["samples"] = latent.Data,
            }
        };
        return Nodes.AddNamedNode(latentUpscale);
    }

    if (upscaleInfo.Type != ComfyUpscalerType.ESRGAN)
    {
        throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
    }

    // The upscale model consumes an image, so decode the latent first.
    var decoded = Nodes.AddTypedNode(
        new VAEDecode { Name = $"{name}_VAEDecode", Samples = latent, Vae = vae }
    );

    var modelUpscaled = Group_UpscaleWithModel(
        $"{name}_ModelUpscale",
        upscaleInfo.Name,
        decoded.Output
    );

    // The model's scale factor is fixed (e.g. 2x/4x), so resize again to hit the requested size.
    var resized = Nodes.AddNamedNode(
        ImageScale($"{name}_ImageScale", modelUpscaled.Output, "bilinear", height, width, false)
    );

    // Encode the resized image back into latent space.
    return Nodes.AddTypedNode(
        new VAEEncode { Name = $"{name}_VAEEncode", Pixels = resized.Output, Vae = vae }
    );
}
/// <summary>
/// Adds nodes that resize <paramref name="latent"/> to <paramref name="width"/> x
/// <paramref name="height"/> and produce an image, returning the last added node.
/// </summary>
/// <remarks>
/// <see cref="ComfyUpscalerType.Latent"/>: a <c>LatentUpscale</c> node (crop disabled) followed by a
/// decode with <paramref name="vae"/>.
/// <see cref="ComfyUpscalerType.ESRGAN"/>: decode with <paramref name="vae"/>, upscale with the model,
/// then resize bilinearly to the requested size; no re-encode.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown for any other upscaler type.</exception>
public NamedComfyNode Group_LatentUpscaleToImage(
    string name,
    LatentNodeConnection latent,
    VAENodeConnection vae,
    ComfyUpscaler upscaleInfo,
    int width,
    int height
)
{
    if (upscaleInfo.Type == ComfyUpscalerType.Latent)
    {
        // Resize in latent space first, then decode the resized latent.
        var resizedLatent = Nodes.AddNamedNode(
            new NamedComfyNode($"{name}_LatentUpscale")
            {
                ClassType = "LatentUpscale",
                Inputs = new Dictionary
                {
                    ["upscale_method"] = upscaleInfo.Name,
                    ["width"] = width,
                    ["height"] = height,
                    ["crop"] = "disabled",
                    ["samples"] = latent.Data,
                }
            }
        );

        return Nodes.AddTypedNode(
            new VAEDecode { Name = $"{name}_VAEDecode", Samples = resizedLatent.Output, Vae = vae }
        );
    }

    if (upscaleInfo.Type != ComfyUpscalerType.ESRGAN)
    {
        throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
    }

    // The upscale model consumes an image, so decode the latent first.
    var decoded = Nodes.AddTypedNode(
        new VAEDecode { Name = $"{name}_VAEDecode", Samples = latent, Vae = vae }
    );

    var modelUpscaled = Group_UpscaleWithModel(
        $"{name}_ModelUpscale",
        upscaleInfo.Name,
        decoded.Output
    );

    // The model's scale factor is fixed (e.g. 2x/4x), so resize again to the requested size.
    // The result is already an image, so the resize node is returned directly.
    return Nodes.AddNamedNode(
        ImageScale($"{name}_ImageScale", modelUpscaled.Output, "bilinear", height, width, false)
    );
}
/// <summary>
/// Adds nodes that resize <paramref name="image"/> to <paramref name="width"/> x
/// <paramref name="height"/> in image space and returns the last added node.
/// </summary>
/// <remarks>
/// <see cref="ComfyUpscalerType.Latent"/>: a single <c>ImageScale</c> node (crop disabled) using the
/// upscaler's name as the resize method. Its name stays <c>{name}_LatentUpscale</c> so the generated
/// node name is unchanged.
/// <see cref="ComfyUpscalerType.ESRGAN"/>: upscale with the model, then resize bilinearly to the
/// requested size.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown for any other upscaler type.</exception>
public NamedComfyNode Group_UpscaleToImage(
    string name,
    ImageNodeConnection image,
    ComfyUpscaler upscaleInfo,
    int width,
    int height
)
{
    if (upscaleInfo.Type == ComfyUpscalerType.Latent)
    {
        // Built through the ImageScale factory so the "image" input is image.Data, like every other
        // connection input in this builder, rather than the connection wrapper object itself.
        return Nodes.AddNamedNode(
            ImageScale($"{name}_LatentUpscale", image, upscaleInfo.Name, height, width, false)
        );
    }

    if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
    {
        var modelUpscaler = Group_UpscaleWithModel(
            $"{name}_ModelUpscale",
            upscaleInfo.Name,
            image
        );

        // The model's scale factor is fixed (e.g. 2x/4x), so resize again to the requested size.
        // Input and output are both images; no latent conversion is needed.
        return Nodes.AddNamedNode(
            ImageScale($"{name}_ImageScale", modelUpscaler.Output, "bilinear", height, width, false)
        );
    }

    throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
/// <summary>
/// Chains one <c>LoraLoader</c> node per entry in <paramref name="loras"/>. Each loader's model and
/// clip outputs feed the next loader; the last loader node is returned.
/// </summary>
/// <param name="name">Prefix for loader names, which become <c>{name}_LoraLoader_{n}</c> with n counting from 1.</param>
/// <param name="model">Model connection fed to the first loader.</param>
/// <param name="clip">CLIP connection fed to the first loader.</param>
/// <param name="loras">LoRA file names with optional weights; a null weight means 1.</param>
/// <exception cref="InvalidOperationException">Thrown when <paramref name="loras"/> is empty.</exception>
public NamedComfyNode Group_LoraLoadMany(
    string name,
    ModelNodeConnection model,
    ClipNodeConnection clip,
    IEnumerable<(string FileName, double? ModelWeight, double? ClipWeight)> loras
)
{
    NamedComfyNode? lastLoader = null;

    foreach (var (index, (fileName, modelWeight, clipWeight)) in loras.Enumerate())
    {
        var loader = LoraLoader(
            $"{name}_LoraLoader_{index + 1}",
            model,
            clip,
            fileName,
            modelWeight ?? 1,
            clipWeight ?? 1
        );
        lastLoader = Nodes.AddNamedNode(loader);

        // The next loader in the chain consumes this loader's outputs.
        model = lastLoader.Output1;
        clip = lastLoader.Output2;
    }

    if (lastLoader is null)
    {
        throw new InvalidOperationException("No lora networks given");
    }

    return lastLoader;
}
/// <summary>
/// Chains one <c>LoraLoader</c> node per entry in <paramref name="loras"/>, identifying each LoRA
/// by its <c>RelativePathFromSharedFolder</c>, and returns the last loader node.
/// </summary>
/// <param name="name">Prefix for loader names, which become <c>{name}_LoraLoader_{n}</c> with n counting from 1.</param>
/// <param name="model">Model connection fed to the first loader.</param>
/// <param name="clip">CLIP connection fed to the first loader.</param>
/// <param name="loras">LoRA model files with optional weights; a null weight means 1.</param>
/// <exception cref="InvalidOperationException">Thrown when <paramref name="loras"/> is empty.</exception>
public NamedComfyNode Group_LoraLoadMany(
    string name,
    ModelNodeConnection model,
    ClipNodeConnection clip,
    IEnumerable<(LocalModelFile ModelFile, double? ModelWeight, double? ClipWeight)> loras
)
{
    // Map each model file to the path passed to LoraLoader, then reuse the file-name overload.
    var byFileName = loras.Select(
        lora => (
            FileName: lora.ModelFile.RelativePathFromSharedFolder,
            lora.ModelWeight,
            lora.ClipWeight
        )
    );

    return Group_LoraLoadMany(name, model, clip, byFileName);
}
/// <summary>
/// Returns the current <see cref="NodeBuilderConnections.Primary"/> as a latent. If it is an image,
/// it is encoded with <see cref="NodeBuilderConnections.GetDefaultVAE"/>.
/// </summary>
/// <exception cref="NullReferenceException">
/// Thrown when there is no primary connection, or when a conversion is needed and no VAE is set.
/// </exception>
public LatentNodeConnection GetPrimaryAsLatent()
{
    var current = Connections.Primary ?? throw new NullReferenceException("No primary connection");

    // The default VAE is looked up only when a conversion is actually needed.
    return current.IsT0
        ? current.AsT0
        : GetPrimaryAsLatent(current, Connections.GetDefaultVAE());
}
/// <summary>
/// Returns <paramref name="primary"/> as a latent. An image is encoded with <paramref name="vae"/>
/// through <see cref="Lambda_ImageToLatent"/>; a latent is returned unchanged.
/// </summary>
public LatentNodeConnection GetPrimaryAsLatent(
    PrimaryNodeConnection primary,
    VAENodeConnection vae
) =>
    primary.Match(
        static alreadyLatent => alreadyLatent,
        pixelInput => Lambda_ImageToLatent(pixelInput, vae)
    );
/// <summary>
/// Returns the current <see cref="NodeBuilderConnections.Primary"/> as a latent. If it is an image,
/// it is encoded with <paramref name="vae"/>.
/// </summary>
/// <exception cref="NullReferenceException">Thrown when there is no primary connection.</exception>
public LatentNodeConnection GetPrimaryAsLatent(VAENodeConnection vae)
{
    var current = Connections.Primary ?? throw new NullReferenceException("No primary connection");
    return current.IsT0 ? current.AsT0 : GetPrimaryAsLatent(current, vae);
}
/// <summary>
/// Returns the current <see cref="NodeBuilderConnections.Primary"/> as an image. If it is a latent,
/// it is decoded with <see cref="NodeBuilderConnections.GetDefaultVAE"/>.
/// </summary>
/// <exception cref="NullReferenceException">
/// Thrown when there is no primary connection, or when a conversion is needed and no VAE is set.
/// </exception>
public ImageNodeConnection GetPrimaryAsImage()
{
    var current = Connections.Primary ?? throw new NullReferenceException("No primary connection");

    // The default VAE is looked up only when a conversion is actually needed.
    return current.IsT1
        ? current.AsT1
        : GetPrimaryAsImage(current, Connections.GetDefaultVAE());
}
/// <summary>
/// Returns <paramref name="primary"/> as an image. A latent is decoded with <paramref name="vae"/>
/// through <see cref="Lambda_LatentToImage"/>; an image is returned unchanged.
/// </summary>
public ImageNodeConnection GetPrimaryAsImage(
    PrimaryNodeConnection primary,
    VAENodeConnection vae
) =>
    primary.Match(
        latentInput => Lambda_LatentToImage(latentInput, vae),
        static alreadyImage => alreadyImage
    );
/// <summary>
/// Returns the current <see cref="NodeBuilderConnections.Primary"/> as an image. If it is a latent,
/// it is decoded with <paramref name="vae"/>.
/// </summary>
/// <exception cref="NullReferenceException">Thrown when there is no primary connection.</exception>
public ImageNodeConnection GetPrimaryAsImage(VAENodeConnection vae)
{
    var current = Connections.Primary ?? throw new NullReferenceException("No primary connection");
    return current.IsT1 ? current.AsT1 : GetPrimaryAsImage(current, vae);
}
/// <summary>
/// Calls <c>NormalizeConnectionTypes</c> on <see cref="Nodes"/> and returns that same instance
/// (not a copy).
/// </summary>
public NodeDictionary ToNodeDictionary()
{
    var nodes = Nodes;
    nodes.NormalizeConnectionTypes();
    return nodes;
}
/// <summary>
/// Mutable working state shared between builder steps: seed and batch settings, base and refiner
/// model/VAE/CLIP/conditioning connections, the current primary (latent or image) connection with
/// its VAE, size, sampler and scheduler, and the output nodes collected so far.
/// </summary>
public class NodeBuilderConnections
{
    public ulong Seed { get; set; }

    /// <summary>Batch size; defaults to 1.</summary>
    public int BatchSize { get; set; } = 1;

    public int? BatchIndex { get; set; }

    // Base model stage
    public ModelNodeConnection? BaseModel { get; set; }
    public VAENodeConnection? BaseVAE { get; set; }
    public ClipNodeConnection? BaseClip { get; set; }
    public ConditioningNodeConnection? BaseConditioning { get; set; }
    public ConditioningNodeConnection? BaseNegativeConditioning { get; set; }

    // Refiner stage
    public ModelNodeConnection? RefinerModel { get; set; }
    public VAENodeConnection? RefinerVAE { get; set; }
    public ClipNodeConnection? RefinerClip { get; set; }
    public ConditioningNodeConnection? RefinerConditioning { get; set; }
    public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; }

    // Current primary connection and its associated settings
    public PrimaryNodeConnection? Primary { get; set; }
    public VAENodeConnection? PrimaryVAE { get; set; }
    public Size PrimarySize { get; set; }
    public ComfySampler? PrimarySampler { get; set; }
    public ComfyScheduler? PrimaryScheduler { get; set; }

    /// <summary>Nodes registered as outputs of the graph.</summary>
    public List OutputNodes { get; } = new();

    /// <summary>Names of <see cref="OutputNodes"/>, in list order (lazily evaluated).</summary>
    public IEnumerable OutputNodeNames => OutputNodes.Select(node => node.Name);

    /// <summary>Returns <see cref="RefinerModel"/> if set, otherwise <see cref="BaseModel"/>.</summary>
    /// <exception cref="NullReferenceException">Thrown when neither is set.</exception>
    public ModelNodeConnection GetRefinerOrBaseModel() =>
        RefinerModel ?? BaseModel ?? throw new NullReferenceException("No Model");

    /// <summary>Returns <see cref="RefinerConditioning"/> if set, otherwise <see cref="BaseConditioning"/>.</summary>
    /// <exception cref="NullReferenceException">Thrown when neither is set.</exception>
    public ConditioningNodeConnection GetRefinerOrBaseConditioning() =>
        RefinerConditioning ?? BaseConditioning ?? throw new NullReferenceException("No Conditioning");

    /// <summary>
    /// Returns <see cref="RefinerNegativeConditioning"/> if set, otherwise <see cref="BaseNegativeConditioning"/>.
    /// </summary>
    /// <exception cref="NullReferenceException">Thrown when neither is set.</exception>
    public ConditioningNodeConnection GetRefinerOrBaseNegativeConditioning() =>
        RefinerNegativeConditioning
        ?? BaseNegativeConditioning
        ?? throw new NullReferenceException("No Negative Conditioning");

    /// <summary>
    /// Returns the first VAE that is set, in the order <see cref="PrimaryVAE"/>, <see cref="RefinerVAE"/>,
    /// <see cref="BaseVAE"/>.
    /// </summary>
    /// <exception cref="NullReferenceException">Thrown when none is set.</exception>
    public VAENodeConnection GetDefaultVAE() =>
        PrimaryVAE ?? RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE");
}

/// <summary>
/// Working connection state for the graph being built.
/// </summary>
public NodeBuilderConnections Connections { get; } = new();
}