using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.Tokens;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
/// <summary>
/// Builder functions for comfy nodes
/// </summary>
/// <remarks>
/// The static factory methods only construct node descriptions; the caller decides whether to add
/// them to a graph. The instance <c>Lambda_*</c> and <c>Group_*</c> helpers additionally add every
/// node they create to <see cref="Nodes"/>.
/// </remarks>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public class ComfyNodeBuilder
{
    /// <summary>
    /// Nodes added to the graph being built.
    /// </summary>
    public NodeDictionary Nodes { get; } = new();

    /// <summary>
    /// Shared connections keyed by their connection type; accessed through
    /// <see cref="GetConnection{TConnection}"/> and <see cref="SetConnection{TConnection}"/>.
    /// </summary>
    public Dictionary<Type, NodeConnectionBase> GlobalConnections { get; } = new();

    /// <summary>
    /// Returns the first 8 characters of a new GUID, used to make generated node names unique.
    /// </summary>
    private static string GetRandomPrefix() => Guid.NewGuid().ToString()[..8];

    /// <summary>
    /// Creates a <c>VAEEncode</c> node that encodes <paramref name="pixels"/> using <paramref name="vae"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="pixels">Image input.</param>
    /// <param name="vae">VAE input.</param>
    public static NamedComfyNode<LatentNodeConnection> VAEEncode(
        string name,
        ImageNodeConnection pixels,
        VAENodeConnection vae
    )
    {
        return new NamedComfyNode<LatentNodeConnection>(name)
        {
            ClassType = "VAEEncode",
            Inputs = new Dictionary<string, object?>
            {
                ["pixels"] = pixels.Data,
                ["vae"] = vae.Data
            }
        };
    }

    /// <summary>
    /// Creates a <c>VAEDecode</c> node that decodes <paramref name="samples"/> using <paramref name="vae"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="samples">Latent input.</param>
    /// <param name="vae">VAE input.</param>
    public static NamedComfyNode<ImageNodeConnection> VAEDecode(
        string name,
        LatentNodeConnection samples,
        VAENodeConnection vae
    )
    {
        return new NamedComfyNode<ImageNodeConnection>(name)
        {
            ClassType = "VAEDecode",
            Inputs = new Dictionary<string, object?>
            {
                ["samples"] = samples.Data,
                ["vae"] = vae.Data
            }
        };
    }

    /// <summary>
    /// Creates a <c>KSampler</c> node.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="model">Model input.</param>
    /// <param name="seed">Sampling seed.</param>
    /// <param name="steps">Number of sampling steps.</param>
    /// <param name="cfg">CFG scale.</param>
    /// <param name="sampler">Sampler; its <c>Name</c> is sent as <c>sampler_name</c>.</param>
    /// <param name="scheduler">Scheduler; its <c>Name</c> is sent as <c>scheduler</c>.</param>
    /// <param name="positive">Positive conditioning.</param>
    /// <param name="negative">Negative conditioning.</param>
    /// <param name="latentImage">Starting latent.</param>
    /// <param name="denoise">Denoise strength.</param>
    public static NamedComfyNode<LatentNodeConnection> KSampler(
        string name,
        ModelNodeConnection model,
        ulong seed,
        int steps,
        double cfg,
        ComfySampler sampler,
        ComfyScheduler scheduler,
        ConditioningNodeConnection positive,
        ConditioningNodeConnection negative,
        LatentNodeConnection latentImage,
        double denoise
    )
    {
        return new NamedComfyNode<LatentNodeConnection>(name)
        {
            ClassType = "KSampler",
            Inputs = new Dictionary<string, object?>
            {
                ["model"] = model.Data,
                ["seed"] = seed,
                ["steps"] = steps,
                ["cfg"] = cfg,
                ["sampler_name"] = sampler.Name,
                ["scheduler"] = scheduler.Name,
                ["positive"] = positive.Data,
                ["negative"] = negative.Data,
                ["latent_image"] = latentImage.Data,
                ["denoise"] = denoise
            }
        };
    }

    /// <summary>
    /// Creates a <c>KSamplerAdvanced</c> node.
    /// </summary>
    /// <remarks>
    /// The boolean flags are sent as the strings <c>"enable"</c> / <c>"disable"</c>.
    /// </remarks>
    /// <param name="name">Unique node name.</param>
    /// <param name="model">Model input.</param>
    /// <param name="addNoise">Sent as <c>add_noise</c>.</param>
    /// <param name="noiseSeed">Sent as <c>noise_seed</c>.</param>
    /// <param name="steps">Total number of steps.</param>
    /// <param name="cfg">CFG scale.</param>
    /// <param name="sampler">Sampler; its <c>Name</c> is sent as <c>sampler_name</c>.</param>
    /// <param name="scheduler">Scheduler; its <c>Name</c> is sent as <c>scheduler</c>.</param>
    /// <param name="positive">Positive conditioning.</param>
    /// <param name="negative">Negative conditioning.</param>
    /// <param name="latentImage">Starting latent.</param>
    /// <param name="startAtStep">Sent as <c>start_at_step</c>.</param>
    /// <param name="endAtStep">Sent as <c>end_at_step</c>.</param>
    /// <param name="returnWithLeftoverNoise">Sent as <c>return_with_leftover_noise</c>.</param>
    public static NamedComfyNode<LatentNodeConnection> KSamplerAdvanced(
        string name,
        ModelNodeConnection model,
        bool addNoise,
        ulong noiseSeed,
        int steps,
        double cfg,
        ComfySampler sampler,
        ComfyScheduler scheduler,
        ConditioningNodeConnection positive,
        ConditioningNodeConnection negative,
        LatentNodeConnection latentImage,
        int startAtStep,
        int endAtStep,
        bool returnWithLeftoverNoise
    )
    {
        return new NamedComfyNode<LatentNodeConnection>(name)
        {
            ClassType = "KSamplerAdvanced",
            Inputs = new Dictionary<string, object?>
            {
                ["model"] = model.Data,
                ["add_noise"] = addNoise ? "enable" : "disable",
                ["noise_seed"] = noiseSeed,
                ["steps"] = steps,
                ["cfg"] = cfg,
                ["sampler_name"] = sampler.Name,
                ["scheduler"] = scheduler.Name,
                ["positive"] = positive.Data,
                ["negative"] = negative.Data,
                ["latent_image"] = latentImage.Data,
                ["start_at_step"] = startAtStep,
                ["end_at_step"] = endAtStep,
                ["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable"
            }
        };
    }

    /// <summary>
    /// Creates an <c>EmptyLatentImage</c> node with the given batch size and dimensions.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="batchSize">Sent as <c>batch_size</c>.</param>
    /// <param name="height">Sent as <c>height</c>.</param>
    /// <param name="width">Sent as <c>width</c>.</param>
    public static NamedComfyNode<LatentNodeConnection> EmptyLatentImage(
        string name,
        int batchSize,
        int height,
        int width
    )
    {
        return new NamedComfyNode<LatentNodeConnection>(name)
        {
            ClassType = "EmptyLatentImage",
            Inputs = new Dictionary<string, object?>
            {
                ["batch_size"] = batchSize,
                ["height"] = height,
                ["width"] = width,
            }
        };
    }

    /// <summary>
    /// Creates an <c>ImageUpscaleWithModel</c> node that applies <paramref name="upscaleModel"/> to
    /// <paramref name="image"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="upscaleModel">Loaded upscale model (see <see cref="UpscaleModelLoader"/>).</param>
    /// <param name="image">Image input.</param>
    public static NamedComfyNode<ImageNodeConnection> ImageUpscaleWithModel(
        string name,
        UpscaleModelNodeConnection upscaleModel,
        ImageNodeConnection image
    )
    {
        return new NamedComfyNode<ImageNodeConnection>(name)
        {
            ClassType = "ImageUpscaleWithModel",
            Inputs = new Dictionary<string, object?>
            {
                ["upscale_model"] = upscaleModel.Data,
                ["image"] = image.Data
            }
        };
    }

    /// <summary>
    /// Creates an <c>UpscaleModelLoader</c> node for the model file <paramref name="modelName"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="modelName">Sent as <c>model_name</c>.</param>
    public static NamedComfyNode<UpscaleModelNodeConnection> UpscaleModelLoader(
        string name,
        string modelName
    )
    {
        return new NamedComfyNode<UpscaleModelNodeConnection>(name)
        {
            ClassType = "UpscaleModelLoader",
            Inputs = new Dictionary<string, object?> { ["model_name"] = modelName }
        };
    }

    /// <summary>
    /// Creates an <c>ImageScale</c> node that resizes <paramref name="image"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="image">Image input.</param>
    /// <param name="method">Sent as <c>upscale_method</c>.</param>
    /// <param name="height">Target height.</param>
    /// <param name="width">Target width.</param>
    /// <param name="crop"><c>true</c> sends <c>"center"</c>, <c>false</c> sends <c>"disabled"</c>.</param>
    public static NamedComfyNode<ImageNodeConnection> ImageScale(
        string name,
        ImageNodeConnection image,
        string method,
        int height,
        int width,
        bool crop
    )
    {
        return new NamedComfyNode<ImageNodeConnection>(name)
        {
            ClassType = "ImageScale",
            Inputs = new Dictionary<string, object?>
            {
                ["image"] = image.Data,
                ["upscale_method"] = method,
                ["height"] = height,
                ["width"] = width,
                ["crop"] = crop ? "center" : "disabled"
            }
        };
    }

    /// <summary>
    /// Creates a <c>VAELoader</c> node for the VAE file <paramref name="vaeModelName"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="vaeModelName">Sent as <c>vae_name</c>.</param>
    public static NamedComfyNode<VAENodeConnection> VAELoader(string name, string vaeModelName)
    {
        return new NamedComfyNode<VAENodeConnection>(name)
        {
            ClassType = "VAELoader",
            Inputs = new Dictionary<string, object?> { ["vae_name"] = vaeModelName }
        };
    }

    /// <summary>
    /// Creates a <c>LoraLoader</c> node applying <paramref name="loraName"/> to a model and clip.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="model">Model input.</param>
    /// <param name="clip">Clip input.</param>
    /// <param name="loraName">Sent as <c>lora_name</c>.</param>
    /// <param name="strengthModel">Sent as <c>strength_model</c>.</param>
    /// <param name="strengthClip">Sent as <c>strength_clip</c>.</param>
    /// <returns>Node whose <c>Output1</c> is the model and <c>Output2</c> is the clip.</returns>
    public static NamedComfyNode<ModelNodeConnection, ClipNodeConnection> LoraLoader(
        string name,
        ModelNodeConnection model,
        ClipNodeConnection clip,
        string loraName,
        double strengthModel,
        double strengthClip
    )
    {
        return new NamedComfyNode<ModelNodeConnection, ClipNodeConnection>(name)
        {
            ClassType = "LoraLoader",
            Inputs = new Dictionary<string, object?>
            {
                ["model"] = model.Data,
                ["clip"] = clip.Data,
                ["lora_name"] = loraName,
                ["strength_model"] = strengthModel,
                ["strength_clip"] = strengthClip
            }
        };
    }

    /// <summary>
    /// Creates a <c>CheckpointLoaderSimple</c> node for the checkpoint file <paramref name="modelName"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="modelName">Sent as <c>ckpt_name</c>.</param>
    /// <returns>Node whose outputs are the model, clip and VAE, in that order.</returns>
    public static NamedComfyNode<
        ModelNodeConnection,
        ClipNodeConnection,
        VAENodeConnection
    > CheckpointLoaderSimple(string name, string modelName)
    {
        return new NamedComfyNode<ModelNodeConnection, ClipNodeConnection, VAENodeConnection>(name)
        {
            ClassType = "CheckpointLoaderSimple",
            Inputs = new Dictionary<string, object?> { ["ckpt_name"] = modelName }
        };
    }

    /// <summary>
    /// Creates a <c>CLIPTextEncode</c> node that encodes <paramref name="text"/> with <paramref name="clip"/>.
    /// </summary>
    /// <param name="name">Unique node name.</param>
    /// <param name="clip">Clip input.</param>
    /// <param name="text">Prompt text.</param>
    public static NamedComfyNode<ConditioningNodeConnection> ClipTextEncode(
        string name,
        ClipNodeConnection clip,
        string text
    )
    {
        return new NamedComfyNode<ConditioningNodeConnection>(name)
        {
            ClassType = "CLIPTextEncode",
            Inputs = new Dictionary<string, object?> { ["clip"] = clip.Data, ["text"] = text }
        };
    }

    /// <summary>
    /// Adds a randomly-named <c>VAEDecode</c> node to <see cref="Nodes"/> and returns its image output.
    /// </summary>
    public ImageNodeConnection Lambda_LatentToImage(
        LatentNodeConnection latent,
        VAENodeConnection vae
    )
    {
        return Nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output;
    }

    /// <summary>
    /// Adds a randomly-named <c>VAEEncode</c> node to <see cref="Nodes"/> and returns its latent output.
    /// </summary>
    public LatentNodeConnection Lambda_ImageToLatent(
        ImageNodeConnection pixels,
        VAENodeConnection vae
    )
    {
        return Nodes.AddNamedNode(VAEEncode($"{GetRandomPrefix()}_VAEEncode", pixels, vae)).Output;
    }

    /// <summary>
    /// Get a global connection for a given type
    /// </summary>
    /// <typeparam name="TConnection">The connection type used as the lookup key.</typeparam>
    /// <exception cref="InvalidOperationException">
    /// No connection of type <typeparamref name="TConnection"/> has been set.
    /// </exception>
    public TConnection GetConnection<TConnection>()
        where TConnection : NodeConnectionBase
    {
        if (GlobalConnections.TryGetValue(typeof(TConnection), out var connection))
        {
            return (TConnection)connection;
        }

        throw new InvalidOperationException($"No global connection of type {typeof(TConnection)}");
    }

    /// <summary>
    /// Set a global connection for a given type, replacing any previous connection of that type.
    /// </summary>
    /// <typeparam name="TConnection">The connection type used as the lookup key.</typeparam>
    public void SetConnection<TConnection>(TConnection connection)
        where TConnection : NodeConnectionBase
    {
        GlobalConnections[typeof(TConnection)] = connection;
    }

    /// <summary>
    /// Create a group node that upscales a given image with a given model
    /// </summary>
    /// <remarks>
    /// Adds <c>{name}_UpscaleModelLoader</c> and <c>{name}_ImageUpscaleWithModel</c> to <see cref="Nodes"/>.
    /// </remarks>
    /// <returns>The <c>ImageUpscaleWithModel</c> node.</returns>
    public NamedComfyNode<ImageNodeConnection> Group_UpscaleWithModel(
        string name,
        string modelName,
        ImageNodeConnection image
    )
    {
        var modelLoader = Nodes.AddNamedNode(
            UpscaleModelLoader($"{name}_UpscaleModelLoader", modelName)
        );

        var upscaler = Nodes.AddNamedNode(
            ImageUpscaleWithModel($"{name}_ImageUpscaleWithModel", modelLoader.Output, image)
        );

        return upscaler;
    }

    /// <summary>
    /// Create a group of nodes that upscales a latent to the given size and returns a latent.
    /// </summary>
    /// <remarks>
    /// For <c>ComfyUpscalerType.Latent</c> the latent is resized directly. For
    /// <c>ComfyUpscalerType.ESRGAN</c> the latent is decoded, upscaled with the model, resized to
    /// exactly <paramref name="width"/> x <paramref name="height"/>, and encoded back to latent space.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The upscaler type is not supported.</exception>
    public NamedComfyNode<LatentNodeConnection> Group_UpscaleToLatent(
        string name,
        LatentNodeConnection latent,
        VAENodeConnection vae,
        ComfyUpscaler upscaleInfo,
        int width,
        int height
    )
    {
        if (upscaleInfo.Type == ComfyUpscalerType.Latent)
        {
            return AddLatentUpscale(name, latent, upscaleInfo.Name, width, height);
        }

        if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
        {
            var resizedImage = AddModelUpscaleResized(
                name,
                latent,
                vae,
                upscaleInfo.Name,
                width,
                height
            );

            // Convert back to latent space
            return Nodes.AddNamedNode(VAEEncode($"{name}_VAEEncode", resizedImage.Output, vae));
        }

        throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
    }

    /// <summary>
    /// Create a group of nodes that upscales a latent to the given size and returns an image.
    /// </summary>
    /// <remarks>
    /// For <c>ComfyUpscalerType.Latent</c> the latent is resized and then decoded. For
    /// <c>ComfyUpscalerType.ESRGAN</c> the latent is decoded, upscaled with the model, and resized to
    /// exactly <paramref name="width"/> x <paramref name="height"/>.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The upscaler type is not supported.</exception>
    public NamedComfyNode<ImageNodeConnection> Group_UpscaleToImage(
        string name,
        LatentNodeConnection latent,
        VAENodeConnection vae,
        ComfyUpscaler upscaleInfo,
        int width,
        int height
    )
    {
        if (upscaleInfo.Type == ComfyUpscalerType.Latent)
        {
            var latentUpscale = AddLatentUpscale(name, latent, upscaleInfo.Name, width, height);

            // Convert to image space
            return Nodes.AddNamedNode(VAEDecode($"{name}_VAEDecode", latentUpscale.Output, vae));
        }

        if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
        {
            // Already in image space, so no conversion back to latent is needed
            return AddModelUpscaleResized(name, latent, vae, upscaleInfo.Name, width, height);
        }

        throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
    }

    /// <summary>
    /// Adds a <c>LatentUpscale</c> node named <c>{name}_LatentUpscale</c> that resizes
    /// <paramref name="latent"/> to <paramref name="width"/> x <paramref name="height"/> with cropping disabled.
    /// </summary>
    private NamedComfyNode<LatentNodeConnection> AddLatentUpscale(
        string name,
        LatentNodeConnection latent,
        string upscaleMethod,
        int width,
        int height
    )
    {
        return Nodes.AddNamedNode(
            new NamedComfyNode<LatentNodeConnection>($"{name}_LatentUpscale")
            {
                ClassType = "LatentUpscale",
                Inputs = new Dictionary<string, object?>
                {
                    ["upscale_method"] = upscaleMethod,
                    ["width"] = width,
                    ["height"] = height,
                    ["crop"] = "disabled",
                    ["samples"] = latent.Data,
                }
            }
        );
    }

    /// <summary>
    /// Decodes <paramref name="latent"/> (<c>{name}_VAEDecode</c>), upscales it with the model
    /// <paramref name="modelName"/> (<c>{name}_ModelUpscale_*</c>), then resizes the result
    /// (<c>{name}_ImageScale</c>, bilinear, no crop) to the requested size.
    /// </summary>
    /// <returns>The final <c>ImageScale</c> node.</returns>
    private NamedComfyNode<ImageNodeConnection> AddModelUpscaleResized(
        string name,
        LatentNodeConnection latent,
        VAENodeConnection vae,
        string modelName,
        int width,
        int height
    )
    {
        // Convert to image space
        var samplerImage = Nodes.AddNamedNode(VAEDecode($"{name}_VAEDecode", latent, vae));

        // Do group upscale
        var modelUpscaler = Group_UpscaleWithModel(
            $"{name}_ModelUpscale",
            modelName,
            samplerImage.Output
        );

        // The model upscale factor is fixed by the model (e.g. 2x/4x), so resize to the exact target size
        return Nodes.AddNamedNode(
            ImageScale(
                $"{name}_ImageScale",
                modelUpscaler.Output,
                "bilinear",
                height: height,
                width: width,
                crop: false
            )
        );
    }

    /// <summary>
    /// Create a group node that loads multiple Lora's in series
    /// </summary>
    /// <remarks>
    /// Each loader is named <c>{name}_LoraLoader_{n}</c> (1-based) and consumes the model/clip outputs
    /// of the previous one. Missing weights default to 1.
    /// </remarks>
    /// <returns>The last <c>LoraLoader</c> node in the chain.</returns>
    /// <exception cref="InvalidOperationException"><paramref name="loras"/> is empty.</exception>
    public NamedComfyNode<ModelNodeConnection, ClipNodeConnection> Group_LoraLoadMany(
        string name,
        ModelNodeConnection model,
        ClipNodeConnection clip,
        IEnumerable<(string FileName, double? ModelWeight, double? ClipWeight)> loras
    )
    {
        NamedComfyNode<ModelNodeConnection, ClipNodeConnection>? currentNode = null;

        foreach (var (i, loraNetwork) in loras.Enumerate())
        {
            currentNode = Nodes.AddNamedNode(
                LoraLoader(
                    $"{name}_LoraLoader_{i + 1}",
                    model,
                    clip,
                    loraNetwork.FileName,
                    loraNetwork.ModelWeight ?? 1,
                    loraNetwork.ClipWeight ?? 1
                )
            );

            // Chain: the next loader takes this loader's model and clip outputs
            model = currentNode.Output1;
            clip = currentNode.Output2;
        }

        return currentNode ?? throw new InvalidOperationException("No lora networks given");
    }

    /// <summary>
    /// Create a group node that loads multiple Lora's in series
    /// </summary>
    /// <remarks>
    /// Same as the file-name overload, using each model file's <c>FileName</c>.
    /// </remarks>
    /// <returns>The last <c>LoraLoader</c> node in the chain.</returns>
    /// <exception cref="InvalidOperationException"><paramref name="loras"/> is empty.</exception>
    public NamedComfyNode<ModelNodeConnection, ClipNodeConnection> Group_LoraLoadMany(
        string name,
        ModelNodeConnection model,
        ClipNodeConnection clip,
        IEnumerable<(LocalModelFile ModelFile, double? ModelWeight, double? ClipWeight)> loras
    )
    {
        // Delegate to the string overload so both build identical node chains
        return Group_LoraLoadMany(
            name,
            model,
            clip,
            loras.Select(lora => (lora.ModelFile.FileName, lora.ModelWeight, lora.ClipWeight))
        );
    }

    /// <summary>
    /// Convert to a NodeDictionary
    /// </summary>
    /// <remarks>
    /// Calls <c>NormalizeConnectionTypes</c> on <see cref="Nodes"/> and returns that same instance.
    /// </remarks>
    public NodeDictionary ToNodeDictionary()
    {
        Nodes.NormalizeConnectionTypes();
        return Nodes;
    }

    /// <summary>
    /// Well-known connections tracked while building a graph. Each refiner connection is optional
    /// and falls back to its base connection.
    /// </summary>
    public class NodeBuilderConnections
    {
        public ModelNodeConnection? BaseModel { get; set; }
        public VAENodeConnection? BaseVAE { get; set; }
        public ConditioningNodeConnection? BaseConditioning { get; set; }
        public ConditioningNodeConnection? BaseNegativeConditioning { get; set; }

        public ModelNodeConnection? RefinerModel { get; set; }
        public VAENodeConnection? RefinerVAE { get; set; }
        public ConditioningNodeConnection? RefinerConditioning { get; set; }
        public ConditioningNodeConnection? RefinerNegativeConditioning { get; set; }

        public LatentNodeConnection? Latent { get; set; }
        public Size LatentSize { get; set; }

        public ImageNodeConnection? Image { get; set; }

        /// <summary>
        /// Gets the latent size scaled by a given factor, with each dimension rounded down.
        /// </summary>
        public Size GetScaledLatentSize(double scale)
        {
            return new Size(
                (int)Math.Floor(LatentSize.Width * scale),
                (int)Math.Floor(LatentSize.Height * scale)
            );
        }

        /// <summary>Returns <see cref="RefinerVAE"/> if set, otherwise <see cref="BaseVAE"/>.</summary>
        /// <exception cref="NullReferenceException">Neither is set.</exception>
        public VAENodeConnection GetRefinerOrBaseVAE()
        {
            return RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE");
        }

        /// <summary>Returns <see cref="RefinerModel"/> if set, otherwise <see cref="BaseModel"/>.</summary>
        /// <exception cref="NullReferenceException">Neither is set.</exception>
        public ModelNodeConnection GetRefinerOrBaseModel()
        {
            return RefinerModel ?? BaseModel ?? throw new NullReferenceException("No Model");
        }

        /// <summary>
        /// Returns <see cref="RefinerConditioning"/> if set, otherwise <see cref="BaseConditioning"/>.
        /// </summary>
        /// <exception cref="NullReferenceException">Neither is set.</exception>
        public ConditioningNodeConnection GetRefinerOrBaseConditioning()
        {
            return RefinerConditioning
                ?? BaseConditioning
                ?? throw new NullReferenceException("No Conditioning");
        }

        /// <summary>
        /// Returns <see cref="RefinerNegativeConditioning"/> if set, otherwise
        /// <see cref="BaseNegativeConditioning"/>.
        /// </summary>
        /// <exception cref="NullReferenceException">Neither is set.</exception>
        public ConditioningNodeConnection GetRefinerOrBaseNegativeConditioning()
        {
            return RefinerNegativeConditioning
                ?? BaseNegativeConditioning
                ?? throw new NullReferenceException("No Negative Conditioning");
        }
    }

    /// <summary>
    /// Well-known connections for the graph being built.
    /// </summary>
    public NodeBuilderConnections Connections { get; } = new();
}