using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
///
/// Builder functions for comfy nodes
///
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public class ComfyNodeBuilder
{
public NodeDictionary Nodes { get; } = new();
private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8];
private string GetUniqueName(string nameBase)
{
var name = $"{nameBase}_1";
for (var i = 0; Nodes.ContainsKey(name); i++)
{
if (i > 1_000_000)
{
throw new InvalidOperationException(
$"Could not find unique name for base {nameBase}"
);
}
name = $"{nameBase}_{i + 1}";
}
return name;
}
public static NamedComfyNode VAEEncode(
string name,
ImageNodeConnection pixels,
VAENodeConnection vae
)
{
return new NamedComfyNode(name)
{
ClassType = "VAEEncode",
Inputs = new Dictionary
{
["pixels"] = pixels.Data,
["vae"] = vae.Data
}
};
}
public static NamedComfyNode VAEDecode(
string name,
LatentNodeConnection samples,
VAENodeConnection vae
)
{
return new NamedComfyNode(name)
{
ClassType = "VAEDecode",
Inputs = new Dictionary
{
["samples"] = samples.Data,
["vae"] = vae.Data
}
};
}
public static NamedComfyNode KSampler(
string name,
ModelNodeConnection model,
ulong seed,
int steps,
double cfg,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
double denoise
)
{
return new NamedComfyNode(name)
{
ClassType = "KSampler",
Inputs = new Dictionary
{
["model"] = model.Data,
["seed"] = seed,
["steps"] = steps,
["cfg"] = cfg,
["sampler_name"] = sampler.Name,
["scheduler"] = scheduler.Name,
["positive"] = positive.Data,
["negative"] = negative.Data,
["latent_image"] = latentImage.Data,
["denoise"] = denoise
}
};
}
public static NamedComfyNode KSamplerAdvanced(
string name,
ModelNodeConnection model,
bool addNoise,
ulong noiseSeed,
int steps,
double cfg,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
int startAtStep,
int endAtStep,
bool returnWithLeftoverNoise
)
{
return new NamedComfyNode(name)
{
ClassType = "KSamplerAdvanced",
Inputs = new Dictionary
{
["model"] = model.Data,
["add_noise"] = addNoise ? "enable" : "disable",
["noise_seed"] = noiseSeed,
["steps"] = steps,
["cfg"] = cfg,
["sampler_name"] = sampler.Name,
["scheduler"] = scheduler.Name,
["positive"] = positive.Data,
["negative"] = negative.Data,
["latent_image"] = latentImage.Data,
["start_at_step"] = startAtStep,
["end_at_step"] = endAtStep,
["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable"
}
};
}
public static NamedComfyNode EmptyLatentImage(
string name,
int batchSize,
int height,
int width
)
{
return new NamedComfyNode(name)
{
ClassType = "EmptyLatentImage",
Inputs = new Dictionary
{
["batch_size"] = batchSize,
["height"] = height,
["width"] = width,
}
};
}
public static NamedComfyNode LatentFromBatch(
string name,
LatentNodeConnection samples,
int batchIndex,
int length
)
{
return new NamedComfyNode(name)
{
ClassType = "LatentFromBatch",
Inputs = new Dictionary
{
["samples"] = samples.Data,
["batch_index"] = batchIndex,
["length"] = length,
}
};
}
public static NamedComfyNode ImageUpscaleWithModel(
string name,
UpscaleModelNodeConnection upscaleModel,
ImageNodeConnection image
)
{
return new NamedComfyNode(name)
{
ClassType = "ImageUpscaleWithModel",
Inputs = new Dictionary
{
["upscale_model"] = upscaleModel.Data,
["image"] = image.Data
}
};
}
public static NamedComfyNode UpscaleModelLoader(
string name,
string modelName
)
{
return new NamedComfyNode(name)
{
ClassType = "UpscaleModelLoader",
Inputs = new Dictionary { ["model_name"] = modelName }
};
}
public static NamedComfyNode ImageScale(
string name,
ImageNodeConnection image,
string method,
int height,
int width,
bool crop
)
{
return new NamedComfyNode(name)
{
ClassType = "ImageScale",
Inputs = new Dictionary
{
["image"] = image.Data,
["upscale_method"] = method,
["height"] = height,
["width"] = width,
["crop"] = crop ? "center" : "disabled"
}
};
}
public record VAELoader : ComfyTypedNodeBase
{
public required string VaeName { get; init; }
}
public static NamedComfyNode LoraLoader(
string name,
ModelNodeConnection model,
ClipNodeConnection clip,
string loraName,
double strengthModel,
double strengthClip
)
{
return new NamedComfyNode(name)
{
ClassType = "LoraLoader",
Inputs = new Dictionary
{
["model"] = model.Data,
["clip"] = clip.Data,
["lora_name"] = loraName,
["strength_model"] = strengthModel,
["strength_clip"] = strengthClip
}
};
}
public record CheckpointLoaderSimple
: ComfyTypedNodeBase
{
public required string CkptName { get; init; }
}
public static NamedComfyNode FreeU(
string name,
ModelNodeConnection model,
double b1,
double b2,
double s1,
double s2
)
{
return new NamedComfyNode(name)
{
ClassType = "FreeU",
Inputs = new Dictionary
{
["model"] = model.Data,
["b1"] = b1,
["b2"] = b2,
["s1"] = s1,
["s2"] = s2
}
};
}
public static NamedComfyNode ClipTextEncode(
string name,
ClipNodeConnection clip,
string text
)
{
return new NamedComfyNode(name)
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary { ["clip"] = clip.Data, ["text"] = text }
};
}
public record LoadImage : ComfyTypedNodeBase
{
///
/// Path relative to the Comfy input directory
///
public required string Image { get; init; }
}
public record PreviewImage : ComfyTypedNodeBase
{
public required ImageNodeConnection Images { get; init; }
}
public record ImageSharpen : ComfyTypedNodeBase
{
public required ImageNodeConnection Image { get; init; }
public required int SharpenRadius { get; init; }
public required double Sigma { get; init; }
public required double Alpha { get; init; }
}
public record ControlNetLoader : ComfyTypedNodeBase
{
public required string ControlNetName { get; init; }
}
public record ControlNetApplyAdvanced
: ComfyTypedNodeBase
{
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required ControlNetNodeConnection ControlNet { get; init; }
public required ImageNodeConnection Image { get; init; }
public required double Strength { get; init; }
public required double StartPercent { get; init; }
public required double EndPercent { get; init; }
}
public ImageNodeConnection Lambda_LatentToImage(
LatentNodeConnection latent,
VAENodeConnection vae
)
{
var name = GetUniqueName("VAEDecode");
return Nodes.AddNamedNode(VAEDecode(name, latent, vae)).Output;
}
public LatentNodeConnection Lambda_ImageToLatent(
ImageNodeConnection pixels,
VAENodeConnection vae
)
{
var name = GetUniqueName("VAEEncode");
return Nodes.AddNamedNode(VAEEncode(name, pixels, vae)).Output;
}
///
/// Create a group node that upscales a given image with a given model
///
public NamedComfyNode Group_UpscaleWithModel(
string name,
string modelName,
ImageNodeConnection image
)
{
var modelLoader = Nodes.AddNamedNode(
UpscaleModelLoader($"{name}_UpscaleModelLoader", modelName)
);
var upscaler = Nodes.AddNamedNode(
ImageUpscaleWithModel($"{name}_ImageUpscaleWithModel", modelLoader.Output, image)
);
return upscaler;
}
///
/// Create a group node that scales a given image to image output
///
public PrimaryNodeConnection Group_Upscale(
string name,
PrimaryNodeConnection primary,
VAENodeConnection vae,
ComfyUpscaler upscaleInfo,
int width,
int height
)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
return primary.Match(
latent =>
Nodes
.AddNamedNode(
new NamedComfyNode($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
)
.Output,
image =>
Nodes
.AddNamedNode(
ImageScale(
$"{name}_ImageUpscale",
image,
upscaleInfo.Name,
height,
width,
false
)
)
.Output
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space if needed
var samplerImage = GetPrimaryAsImage(primary, vae);
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
return resizedScaled.Output;
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
///
/// Create a group node that scales a given image to a given size
///
public NamedComfyNode Group_UpscaleToLatent(
string name,
LatentNodeConnection latent,
VAENodeConnection vae,
ComfyUpscaler upscaleInfo,
int width,
int height
)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
return Nodes.AddNamedNode(
new NamedComfyNode($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space
var samplerImage = Nodes.AddNamedNode(VAEDecode($"{name}_VAEDecode", latent, vae));
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage.Output
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
// Convert back to latent space
return Nodes.AddNamedNode(VAEEncode($"{name}_VAEEncode", resizedScaled.Output, vae));
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
///
/// Create a group node that scales a given image to image output
///
public NamedComfyNode Group_LatentUpscaleToImage(
string name,
LatentNodeConnection latent,
VAENodeConnection vae,
ComfyUpscaler upscaleInfo,
int width,
int height
)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
var latentUpscale = Nodes.AddNamedNode(
new NamedComfyNode($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
);
// Convert to image space
return Nodes.AddNamedNode(VAEDecode($"{name}_VAEDecode", latentUpscale.Output, vae));
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space
var samplerImage = Nodes.AddNamedNode(VAEDecode($"{name}_VAEDecode", latent, vae));
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage.Output
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
// No need to convert back to latent space
return resizedScaled;
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
///
/// Create a group node that scales a given image to image output
///
public NamedComfyNode Group_UpscaleToImage(
string name,
ImageNodeConnection image,
ComfyUpscaler upscaleInfo,
int width,
int height
)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
return Nodes.AddNamedNode(
new NamedComfyNode($"{name}_LatentUpscale")
{
ClassType = "ImageScale",
Inputs = new Dictionary
{
["image"] = image,
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
}
}
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
image
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
// No need to convert back to latent space
return resizedScaled;
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
///
/// Create a group node that loads multiple Lora's in series
///
public NamedComfyNode Group_LoraLoadMany(
string name,
ModelNodeConnection model,
ClipNodeConnection clip,
IEnumerable<(string FileName, double? ModelWeight, double? ClipWeight)> loras
)
{
NamedComfyNode? currentNode = null;
foreach (var (i, loraNetwork) in loras.Enumerate())
{
currentNode = Nodes.AddNamedNode(
LoraLoader(
$"{name}_LoraLoader_{i + 1}",
model,
clip,
loraNetwork.FileName,
loraNetwork.ModelWeight ?? 1,
loraNetwork.ClipWeight ?? 1
)
);
// Connect to previous node
model = currentNode.Output1;
clip = currentNode.Output2;
}
return currentNode ?? throw new InvalidOperationException("No lora networks given");
}
///
/// Create a group node that loads multiple Lora's in series
///
public NamedComfyNode Group_LoraLoadMany(
string name,
ModelNodeConnection model,
ClipNodeConnection clip,
IEnumerable<(LocalModelFile ModelFile, double? ModelWeight, double? ClipWeight)> loras
)
{
NamedComfyNode? currentNode = null;
foreach (var (i, loraNetwork) in loras.Enumerate())
{
currentNode = Nodes.AddNamedNode(
LoraLoader(
$"{name}_LoraLoader_{i + 1}",
model,
clip,
loraNetwork.ModelFile.RelativePathFromSharedFolder,
loraNetwork.ModelWeight ?? 1,
loraNetwork.ClipWeight ?? 1
)
);
// Connect to previous node
model = currentNode.Output1;
clip = currentNode.Output2;
}
return currentNode ?? throw new InvalidOperationException("No lora networks given");
}
///
/// Get or convert latest primary connection to latent
///
public LatentNodeConnection GetPrimaryAsLatent()
{
if (Connections.Primary?.IsT0 == true)
{
return Connections.Primary.AsT0;
}
return GetPrimaryAsLatent(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE")
);
}
///
/// Get or convert latest primary connection to latent
///
public LatentNodeConnection GetPrimaryAsLatent(
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{
return primary.Match(latent => latent, image => Lambda_ImageToLatent(image, vae));
}
///
/// Get or convert latest primary connection to latent
///
public LatentNodeConnection GetPrimaryAsLatent(VAENodeConnection vae)
{
if (Connections.Primary?.IsT0 == true)
{
return Connections.Primary.AsT0;
}
return GetPrimaryAsLatent(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
vae
);
}
///
/// Get or convert latest primary connection to image
///
public ImageNodeConnection GetPrimaryAsImage()
{
if (Connections.Primary?.IsT1 == true)
{
return Connections.Primary.AsT1;
}
return GetPrimaryAsImage(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
Connections.PrimaryVAE ?? throw new NullReferenceException("No primary VAE")
);
}
///
/// Get or convert latest primary connection to image
///
public ImageNodeConnection GetPrimaryAsImage(
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{
return primary.Match(latent => Lambda_LatentToImage(latent, vae), image => image);
}
///
/// Get or convert latest primary connection to image
///
public ImageNodeConnection GetPrimaryAsImage(VAENodeConnection vae)
{
if (Connections.Primary?.IsT1 == true)
{
return Connections.Primary.AsT1;
}
return GetPrimaryAsImage(
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
vae
);
}
///
/// Convert to a NodeDictionary
///
public NodeDictionary ToNodeDictionary()
{
Nodes.NormalizeConnectionTypes();
return Nodes;
}
public class NodeBuilderConnections
{
public ulong Seed { get; set; }
public ModelNodeConnection? BaseModel { get; set; }
public VAENodeConnection? BaseVAE { get; set; }
public ClipNodeConnection? BaseClip { get; set; }
public ConditioningNodeConnection? BaseConditioning { get; set; }
public ConditioningNodeConnection? BaseNegativeConditioning { get; set; }
public ModelNodeConnection? RefinerModel { get; set; }
public VAENodeConnection? RefinerVAE { get; set; }
public ClipNodeConnection? RefinerClip { get; set; }
public ConditioningNodeConnection? RefinerConditioning { get; set; }
public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; }
public PrimaryNodeConnection? Primary { get; set; }
public VAENodeConnection? PrimaryVAE { get; set; }
public Size PrimarySize { get; set; }
public List OutputNodes { get; } = new();
public IEnumerable OutputNodeNames => OutputNodes.Select(n => n.Name);
public ModelNodeConnection GetRefinerOrBaseModel()
{
return RefinerModel ?? BaseModel ?? throw new NullReferenceException("No Model");
}
public ConditioningNodeConnection GetRefinerOrBaseConditioning()
{
return RefinerConditioning
?? BaseConditioning
?? throw new NullReferenceException("No Conditioning");
}
public ConditioningNodeConnection GetRefinerOrBaseNegativeConditioning()
{
return RefinerNegativeConditioning
?? BaseNegativeConditioning
?? throw new NullReferenceException("No Negative Conditioning");
}
}
public NodeBuilderConnections Connections { get; } = new();
}