Browse Source

Move inference logic to ApplyStep

pull/333/head
Ionite 12 months ago
parent
commit
5bc1715a49
No known key found for this signature in database
  1. 307
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 162
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  3. 14
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  4. 52
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs
  5. 48
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs
  6. 118
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  7. 138
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
  8. 8
      StabilityMatrix.Core/Attributes/BoolStringMemberAttribute.cs
  9. 100
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  10. 18
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs

307
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -1,40 +1,38 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Drawing;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.Extensions;
public static class ComfyNodeBuilderExtensions
{
public static void SetupLatentSource(
public static void SetupEmptyLatentSource(
this ComfyNodeBuilder builder,
BatchSizeCardViewModel batchSizeCardViewModel,
SamplerCardViewModel samplerCardViewModel
int width,
int height,
int batchSize = 1,
int? batchIndex = null
)
{
var emptyLatent = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.EmptyLatentImage(
"EmptyLatentImage",
batchSizeCardViewModel.BatchSize,
samplerCardViewModel.Height,
samplerCardViewModel.Width
)
var emptyLatent = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.EmptyLatentImage
{
Name = "EmptyLatentImage",
BatchSize = batchSize,
Height = height,
Width = width
}
);
builder.Connections.Primary = emptyLatent.Output;
builder.Connections.PrimarySize = new Size(
samplerCardViewModel.Width,
samplerCardViewModel.Height
);
builder.Connections.PrimarySize = new Size(width, height);
// If batch index is selected, add a LatentFromBatch
if (batchSizeCardViewModel.IsBatchIndexEnabled)
if (batchIndex is not null)
{
builder.Connections.Primary = builder.Nodes
.AddNamedNode(
@ -42,7 +40,7 @@ public static class ComfyNodeBuilderExtensions
"LatentFromBatch",
builder.GetPrimaryAsLatent(),
// remote expects a 0-based index, vm is 1-based
batchSizeCardViewModel.BatchIndex - 1,
batchIndex.Value - 1,
1
)
)
@ -50,256 +48,63 @@ public static class ComfyNodeBuilderExtensions
}
}
public static void SetupBaseSampler(
public static void SetupImageLatentSource(
this ComfyNodeBuilder builder,
SamplerCardViewModel samplerCardViewModel,
PromptCardViewModel promptCardViewModel,
ModelCardViewModel modelCardViewModel,
IModelIndexService modelIndexService,
Action<ComfyNodeBuilder>? postModelLoad = null
BatchSizeCardViewModel batchSizeCardViewModel,
SamplerCardViewModel samplerCardViewModel
)
{
/*// Load base checkpoint
var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple(
"CheckpointLoader",
modelCardViewModel.SelectedModel?.FileName
?? throw new NullReferenceException("Model not selected")
)
);
builder.Connections.BaseModel = checkpointLoader.GetOutput<ModelNodeConnection>(0);
builder.Connections.BaseClip = checkpointLoader.GetOutput<ClipNodeConnection>(1);
builder.Connections.BaseVAE = checkpointLoader.GetOutput<VAENodeConnection>(2);
builder.Connections.PrimaryVAE = builder.Connections.BaseVAE;
// Run post model load action
postModelLoad?.Invoke(builder);*/
// Load prompts
var prompt = promptCardViewModel.GetPrompt();
prompt.Process();
var negativePrompt = promptCardViewModel.GetNegativePrompt();
negativePrompt.Process();
// If need to load loras, add a group
if (prompt.ExtraNetworks.Count > 0)
{
// Convert to local file names
var lorasGroup = builder.Group_LoraLoadMany(
"Loras",
builder.Connections.BaseModel,
builder.Connections.BaseClip,
prompt.GetExtraNetworksAsLocalModels(modelIndexService)
);
// Set as source
builder.Connections.BaseModel = lorasGroup.Output1;
builder.Connections.BaseClip = lorasGroup.Output2;
}
// Clips
var positiveClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"PositiveCLIP",
builder.Connections.BaseClip,
prompt.ProcessedText
)
);
var negativeClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"NegativeCLIP",
builder.Connections.BaseClip,
negativePrompt.ProcessedText
)
);
builder.Connections.BaseConditioning = positiveClip.Output;
builder.Connections.BaseNegativeConditioning = negativeClip.Output;
// Apply sampler addons (FreeU / ControlNet) to model and conditioning
var samplerStepArgs = new ModuleApplyStepEventArgs
{
Builder = builder,
Temp =
var emptyLatent = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.EmptyLatentImage
{
Model = builder.Connections.BaseModel,
Conditioning = (positiveClip.Output, negativeClip.Output)
Name = "EmptyLatentImage",
BatchSize = batchSizeCardViewModel.BatchSize,
Height = samplerCardViewModel.Height,
Width = samplerCardViewModel.Width
}
};
samplerCardViewModel.ApplyStep(samplerStepArgs);
var model = samplerStepArgs.Temp.Model;
var conditioning = samplerStepArgs.Temp.Conditioning;
// Primary latent encoding vae
var vae =
builder.Connections.PrimaryVAE
?? builder.Connections.BaseVAE
?? throw new ValidationException("No Primary or Base VAE");
var latent = builder.GetPrimaryAsLatent(vae);
// Add base sampler (without refiner)
if (
modelCardViewModel
is not { IsRefinerSelectionEnabled: true, SelectedRefiner.IsDefault: false }
)
{
var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSampler(
"Sampler",
model,
builder.Connections.Seed,
samplerCardViewModel.Steps,
samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Scheduler not selected"),
conditioning.Positive,
conditioning.Negative,
latent,
samplerCardViewModel.DenoiseStrength
)
);
builder.Connections.Primary = sampler.Output;
}
// Add base sampler (with refiner)
else
{
// Total steps is the sum of the base and refiner steps
var totalSteps = samplerCardViewModel.Steps + samplerCardViewModel.RefinerSteps;
var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSamplerAdvanced(
"Sampler",
model,
true,
builder.Connections.Seed,
totalSteps,
samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Sampler not selected"),
conditioning.Positive,
conditioning.Negative,
latent,
0,
samplerCardViewModel.Steps,
true
)
);
builder.Connections.Primary = sampler.Output;
}
}
public static void SetupRefinerSampler(
this ComfyNodeBuilder builder,
SamplerCardViewModel samplerCardViewModel,
PromptCardViewModel promptCardViewModel,
ModelCardViewModel modelCardViewModel,
IModelIndexService modelIndexService,
Action<ComfyNodeBuilder>? postModelLoad = null
)
{
/*// Load refiner checkpoint
var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple(
"Refiner_CheckpointLoader",
modelCardViewModel.SelectedRefiner?.RelativePath
?? throw new NullReferenceException("Model not selected")
)
);
builder.Connections.RefinerModel = checkpointLoader.GetOutput<ModelNodeConnection>(0);
builder.Connections.RefinerClip = checkpointLoader.GetOutput<ClipNodeConnection>(1);
builder.Connections.RefinerVAE = checkpointLoader.GetOutput<VAENodeConnection>(2);
builder.Connections.PrimaryVAE = builder.Connections.RefinerVAE;
// Run post model load action
postModelLoad?.Invoke(builder);*/
// Load prompts
var prompt = promptCardViewModel.GetPrompt();
prompt.Process();
var negativePrompt = promptCardViewModel.GetNegativePrompt();
negativePrompt.Process();
builder.Connections.Primary = emptyLatent.Output;
builder.Connections.PrimarySize = new Size(
samplerCardViewModel.Width,
samplerCardViewModel.Height
);
// If need to load loras, add a group
if (prompt.ExtraNetworks.Count > 0)
// If batch index is selected, add a LatentFromBatch
if (batchSizeCardViewModel.IsBatchIndexEnabled)
{
// Convert to local file names
var lorasGroup = builder.Group_LoraLoadMany(
"Refiner_Loras",
builder.Connections.RefinerModel,
builder.Connections.RefinerClip,
prompt.GetExtraNetworksAsLocalModels(modelIndexService)
);
// Set as source
builder.Connections.RefinerModel = lorasGroup.Output1;
builder.Connections.RefinerClip = lorasGroup.Output2;
builder.Connections.Primary = builder.Nodes
.AddNamedNode(
ComfyNodeBuilder.LatentFromBatch(
"LatentFromBatch",
builder.GetPrimaryAsLatent(),
// remote expects a 0-based index, vm is 1-based
batchSizeCardViewModel.BatchIndex - 1,
1
)
)
.Output;
}
// Clips
var positiveClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"Refiner_PositiveCLIP",
builder.Connections.RefinerClip,
prompt.ProcessedText
)
);
var negativeClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"Refiner_NegativeCLIP",
builder.Connections.RefinerClip,
negativePrompt.ProcessedText
)
);
builder.Connections.RefinerConditioning = positiveClip.Output;
builder.Connections.RefinerNegativeConditioning = negativeClip.Output;
// Add refiner sampler
// Total steps is the sum of the base and refiner steps
var totalSteps = samplerCardViewModel.Steps + samplerCardViewModel.RefinerSteps;
var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSamplerAdvanced(
"Refiner_Sampler",
builder.Connections.RefinerModel,
false,
builder.Connections.Seed,
totalSteps,
samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Sampler not selected"),
positiveClip.Output,
negativeClip.Output,
builder.GetPrimaryAsLatent(),
samplerCardViewModel.Steps,
totalSteps,
false
)
);
builder.Connections.Primary = sampler.Output;
}
public static string SetupOutputImage(this ComfyNodeBuilder builder)
{
var previewImage = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.PreviewImage
{
Name = "SaveImage",
Images = builder.GetPrimaryAsImage(
if (builder.Connections.Primary is null)
throw new ArgumentException("No Primary");
var image = builder.Connections.Primary.Match(
_ =>
builder.GetPrimaryAsImage(
builder.Connections.PrimaryVAE
?? builder.Connections.RefinerVAE
?? builder.Connections.BaseVAE
)
}
?? throw new ArgumentException("No Primary, Refiner, or Base VAE")
),
image => image
);
var previewImage = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.PreviewImage { Name = "SaveImage", Images = image }
);
builder.Connections.OutputNodes.Add(previewImage);

162
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -148,37 +148,11 @@ public class InferenceTextToImageViewModel
StackCardViewModel = vmFactory.Get<StackCardViewModel>();
StackCardViewModel.AddCards(
new LoadableViewModelBase[]
{
ModelCardViewModel,
SamplerCardViewModel,
ModulesCardViewModel,
/*// Free U
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "FreeU";
stackExpander.AddCards(new LoadableViewModelBase[] { FreeUCardViewModel });
}),
// Hires Fix
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "Hires Fix";
stackExpander.AddCards(
new LoadableViewModelBase[]
{
HiresUpscalerCardViewModel,
HiresSamplerCardViewModel
}
);
}),
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "Upscale";
stackExpander.AddCards(new LoadableViewModelBase[] { UpscalerCardViewModel });
}),*/
SeedCardViewModel,
BatchSizeCardViewModel,
}
ModelCardViewModel,
SamplerCardViewModel,
ModulesCardViewModel,
SeedCardViewModel,
BatchSizeCardViewModel
);
// When refiner is provided in model card, enable for sampler
@ -196,128 +170,36 @@ public class InferenceTextToImageViewModel
{
base.BuildPrompt(args);
using var _ = CodeTimer.StartDebug();
var builder = args.Builder;
var nodes = builder.Nodes;
if (args.SeedOverride is { } seed)
{
builder.Connections.Seed = Convert.ToUInt64(seed);
}
else
builder.Connections.Seed = args.SeedOverride switch
{
builder.Connections.Seed = Convert.ToUInt64(SeedCardViewModel.Seed);
}
{ } seed => Convert.ToUInt64(seed),
_ => Convert.ToUInt64(SeedCardViewModel.Seed)
};
// Load models
ModelCardViewModel.ApplyStep(args);
// Setup empty latent
builder.SetupLatentSource(BatchSizeCardViewModel, SamplerCardViewModel);
// Setup base sampling stage
builder.SetupBaseSampler(
SamplerCardViewModel,
PromptCardViewModel,
ModelCardViewModel,
modelIndexService
builder.SetupEmptyLatentSource(
SamplerCardViewModel.Width,
SamplerCardViewModel.Height,
BatchSizeCardViewModel.BatchSize,
BatchSizeCardViewModel.IsBatchIndexEnabled ? BatchSizeCardViewModel.BatchIndex : null
);
// Setup refiner stage if enabled
if (
ModelCardViewModel is
{ IsRefinerSelectionEnabled: true, SelectedRefiner.IsDefault: false }
)
{
builder.SetupRefinerSampler(
SamplerCardViewModel,
PromptCardViewModel,
ModelCardViewModel,
modelIndexService
);
}
// Prompts and loras
PromptCardViewModel.ApplyStep(args);
// Override with custom VAE if enabled
/*if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false })
{
var customVaeLoader = nodes.AddNamedNode(
ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath)
);
// Setup Sampler and Refiner if enabled
SamplerCardViewModel.ApplyStep(args);
builder.Connections.PrimaryVAE = customVaeLoader.Output;
}*/
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
/*if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
// Hires fix if enabled
foreach (var module in ModulesCardViewModel.Cards.OfType<ModuleBase>())
{
// Get new latent size
var hiresSize = builder.Connections.PrimarySize.WithScale(
HiresUpscalerCardViewModel.Scale
);
// Select between latent upscale and normal upscale based on the upscale method
var selectedUpscaler = HiresUpscalerCardViewModel.SelectedUpscaler!.Value;
// If upscaler selected, upscale latent image first
if (selectedUpscaler.Type != ComfyUpscalerType.None)
{
builder.Connections.Primary = builder.Group_Upscale(
"HiresFix",
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
selectedUpscaler,
hiresSize.Width,
hiresSize.Height
);
}
// Use refiner model if set, or base if not
var hiresSampler = nodes.AddNamedNode(
ComfyNodeBuilder.KSampler(
"HiresSampler",
builder.Connections.GetRefinerOrBaseModel(),
builder.Connections.Seed,
HiresSamplerCardViewModel.Steps,
HiresSamplerCardViewModel.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler
HiresSamplerCardViewModel.SelectedSampler
?? SamplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
HiresSamplerCardViewModel.SelectedScheduler
?? SamplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Scheduler not selected"),
builder.Connections.GetRefinerOrBaseConditioning(),
builder.Connections.GetRefinerOrBaseNegativeConditioning(),
builder.GetPrimaryAsLatent(),
HiresSamplerCardViewModel.DenoiseStrength
)
);
// Set as primary
builder.Connections.Primary = hiresSampler.Output;
builder.Connections.PrimarySize = hiresSize;
}*/
// If upscale is enabled, add another upscale group
/*if (IsUpscaleEnabled)
{
var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
var upscaleResult = builder.Group_Upscale(
"PostUpscale",
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
UpscalerCardViewModel.SelectedUpscaler!.Value,
upscaleSize.Width,
upscaleSize.Height
);
builder.Connections.Primary = upscaleResult;
builder.Connections.PrimarySize = upscaleSize;
}*/
module.ApplyStep(args);
}
builder.SetupOutputImage();
}

14
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -129,13 +129,6 @@ public partial class ModelCardViewModel
: ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName);
}
internal class ModelCardModel
{
public string? SelectedModelName { get; init; }
public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; }
}
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
@ -182,4 +175,11 @@ public partial class ModelCardViewModel
ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
};
}
internal class ModelCardModel
{
public string? SelectedModelName { get; init; }
public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; }
}
}

52
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs

@ -1,5 +1,4 @@
using System;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
@ -26,20 +25,41 @@ public class FreeUModule : ModuleBase
{
var card = GetCard<FreeUCardViewModel>();
e.Temp.Model = e.Nodes
.AddNamedNode(
ComfyNodeBuilder.FreeU(
e.Nodes.GetUniqueName("FreeU"),
e.Temp.Model
?? throw new ArgumentException(
"Temp.Model not set on ModuleApplyStepEventArgs"
),
card.B1,
card.B2,
card.S1,
card.S2
// Currently applies to both base and refiner model
// TODO: Add option to apply to either base or refiner
if (e.Builder.Connections.BaseModel is not null)
{
e.Builder.Connections.BaseModel = e.Nodes
.AddTypedNode(
new ComfyNodeBuilder.FreeU
{
Name = e.Nodes.GetUniqueName("FreeU"),
Model = e.Builder.Connections.BaseModel,
B1 = card.B1,
B2 = card.B2,
S1 = card.S1,
S2 = card.S2
}
)
.Output;
}
if (e.Builder.Connections.RefinerModel is not null)
{
e.Builder.Connections.RefinerModel = e.Nodes
.AddTypedNode(
new ComfyNodeBuilder.FreeU
{
Name = e.Nodes.GetUniqueName("Refiner_FreeU"),
Model = e.Builder.Connections.RefinerModel,
B1 = card.B1,
B2 = card.B2,
S1 = card.S1,
S2 = card.S2
}
)
)
.Output;
.Output;
}
}
}

48
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs

@ -1,4 +1,5 @@
using System.ComponentModel.DataAnnotations;
using System;
using System.ComponentModel.DataAnnotations;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
@ -46,34 +47,35 @@ public class HiresFixModule : ModuleBase
{
builder.Connections.Primary = builder.Group_Upscale(
"HiresFix",
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
builder.Connections.Primary ?? throw new ArgumentException("No Primary"),
builder.Connections.PrimaryVAE ?? throw new ArgumentException("No PrimaryVAE"),
selectedUpscaler,
hiresSize.Width,
hiresSize.Height
);
}
// Use refiner model if set, or base if not
var hiresSampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSampler(
builder.Nodes.GetUniqueName("HiresFix_Sampler"),
builder.Connections.GetRefinerOrBaseModel(),
builder.Connections.Seed,
samplerCard.Steps,
samplerCard.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler
samplerCard.SelectedSampler
?? samplerCard.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCard.SelectedScheduler
?? samplerCard.SelectedScheduler
?? throw new ValidationException("Scheduler not selected"),
builder.Connections.GetRefinerOrBaseConditioning(),
builder.Connections.GetRefinerOrBaseNegativeConditioning(),
builder.GetPrimaryAsLatent(),
samplerCard.DenoiseStrength
)
var hiresSampler = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSampler
{
Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"),
Model = builder.Connections.GetRefinerOrBaseModel(),
Seed = builder.Connections.Seed,
Steps = samplerCard.Steps,
Cfg = samplerCard.CfgScale,
SamplerName =
samplerCard.SelectedSampler?.Name
?? e.Builder.Connections.PrimarySampler?.Name
?? throw new ArgumentException("No PrimarySampler"),
Scheduler =
samplerCard.SelectedScheduler?.Name
?? e.Builder.Connections.PrimaryScheduler?.Name
?? throw new ArgumentException("No PrimaryScheduler"),
Positive = builder.Connections.GetRefinerOrBaseConditioning(),
Negative = builder.Connections.GetRefinerOrBaseNegativeConditioning(),
LatentImage = builder.GetPrimaryAsLatent(),
Denoise = samplerCard.DenoiseStrength
}
);
// Set as primary

118
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,4 +1,6 @@
using System.Text;
using System;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
@ -16,6 +18,7 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -23,7 +26,10 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))]
[ManagedService]
[Transient]
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState
public partial class PromptCardViewModel
: LoadableViewModelBase,
IParametersLoadableState,
IComfyStep
{
private readonly IModelIndexService modelIndexService;
@ -64,6 +70,114 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
);
}
/// <summary>
/// Applies the prompt step.
/// Requires:
/// <list type="number">
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseModel"/></item>
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseClip"/></item>
/// </list>
/// Provides:
/// <list type="number">
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseConditioning"/></item>
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseNegativeConditioning"/></item>
/// </list>
/// </summary>
public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Load prompts
var positivePrompt = GetPrompt();
positivePrompt.Process();
var negativePrompt = GetNegativePrompt();
negativePrompt.Process();
// If need to load loras, add a group
if (positivePrompt.ExtraNetworks.Count > 0)
{
var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList();
// Add group to load loras onto model and clip in series
var lorasGroup = e.Builder.Group_LoraLoadMany(
"Loras",
e.Builder.Connections.BaseModel ?? throw new ArgumentException("BaseModel is null"),
e.Builder.Connections.BaseClip ?? throw new ArgumentException("BaseClip is null"),
loras
);
// Set last outputs as base model and clip
e.Builder.Connections.BaseModel = lorasGroup.Output1;
e.Builder.Connections.BaseClip = lorasGroup.Output2;
// Refiner loras
if (e.Builder.Connections.RefinerModel is not null)
{
// Add group to load loras onto refiner model and clip in series
var lorasGroupRefiner = e.Builder.Group_LoraLoadMany(
"Refiner_Loras",
e.Builder.Connections.RefinerModel
?? throw new ArgumentException("RefinerModel is null"),
e.Builder.Connections.RefinerClip
?? throw new ArgumentException("RefinerClip is null"),
loras
);
// Set last outputs as refiner model and clip
e.Builder.Connections.RefinerModel = lorasGroupRefiner.Output1;
e.Builder.Connections.RefinerClip = lorasGroupRefiner.Output2;
}
}
// Clips
var positiveClip = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "PositiveCLIP",
Clip = e.Builder.Connections.BaseClip!,
Text = positivePrompt.ProcessedText
}
);
var negativeClip = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "NegativeCLIP",
Clip = e.Builder.Connections.BaseClip!,
Text = negativePrompt.ProcessedText
}
);
// Set base conditioning from Clips
e.Builder.Connections.BaseConditioning = positiveClip.Output;
e.Builder.Connections.BaseNegativeConditioning = negativeClip.Output;
// Refiner Clips
if (e.Builder.Connections.RefinerModel is not null)
{
var positiveClipRefiner = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "Refiner_PositiveCLIP",
Clip =
e.Builder.Connections.RefinerClip
?? throw new ArgumentException("RefinerClip is null"),
Text = positivePrompt.ProcessedText
}
);
var negativeClipRefiner = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "Refiner_NegativeCLIP",
Clip =
e.Builder.Connections.RefinerClip
?? throw new ArgumentException("RefinerClip is null"),
Text = negativePrompt.ProcessedText
}
);
// Set refiner conditioning from Clips
e.Builder.Connections.RefinerConditioning = positiveClipRefiner.Output;
e.Builder.Connections.RefinerNegativeConditioning = negativeClipRefiner.Output;
}
}
/// <summary>
/// Gets the tokenized Prompt for given text and caches it
/// </summary>

138
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -1,4 +1,6 @@
using System.ComponentModel.DataAnnotations;
using System;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel;
@ -12,6 +14,7 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -75,6 +78,8 @@ public partial class SamplerCardViewModel
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
private int TotalSteps => Steps + RefinerSteps;
public SamplerCardViewModel(
IInferenceClientManager clientManager,
ServiceManager<ViewModelBase> vmFactory
@ -102,6 +107,137 @@ public partial class SamplerCardViewModel
/// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Apply steps from our addons
ApplyAddonSteps(e);
// If "Sampler" is not yet a node, do initial setup
// otherwise do hires setup
if (!e.Nodes.ContainsKey("Sampler"))
{
ApplyStepsInitialSampler(e);
}
else
{
ApplyStepsAdditionalSampler(e);
}
}
private void ApplyStepsInitialSampler(ModuleApplyStepEventArgs e)
{
// Get primary or base VAE
var vae =
e.Builder.Connections.PrimaryVAE
?? e.Builder.Connections.BaseVAE
?? throw new ArgumentException("No Primary or Base VAE");
// Get primary as latent using vae
var primaryLatent = e.Builder.GetPrimaryAsLatent(vae);
// Set primary sampler and scheduler
e.Builder.Connections.PrimarySampler =
SelectedSampler ?? throw new ValidationException("Sampler not selected");
e.Builder.Connections.PrimaryScheduler =
SelectedScheduler ?? throw new ValidationException("Scheduler not selected");
// Use KSampler if no refiner, otherwise need KSamplerAdvanced
if (e.Builder.Connections.RefinerModel is null)
{
// No refiner
var sampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSampler
{
Name = "Sampler",
Model =
e.Builder.Connections.BaseModel
?? throw new ArgumentException("No BaseModel"),
Seed = e.Builder.Connections.Seed,
SamplerName = e.Builder.Connections.PrimarySampler?.Name!,
Scheduler = e.Builder.Connections.PrimaryScheduler?.Name!,
Steps = Steps,
Cfg = CfgScale,
Positive =
e.Builder.Connections.BaseConditioning
?? throw new ArgumentException("No BaseConditioning"),
Negative =
e.Builder.Connections.BaseNegativeConditioning
?? throw new ArgumentException("No BaseNegativeConditioning"),
LatentImage = primaryLatent,
Denoise = DenoiseStrength,
}
);
e.Builder.Connections.Primary = sampler.Output;
}
else
{
// Advanced base sampler for refiner
var sampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSamplerAdvanced
{
Name = "Sampler",
Model =
e.Builder.Connections.BaseModel
?? throw new ArgumentException("No BaseModel"),
AddNoise = true,
NoiseSeed = e.Builder.Connections.Seed,
Steps = TotalSteps,
Cfg = CfgScale,
Sampler = e.Builder.Connections.PrimarySampler?.Name!,
Scheduler = e.Builder.Connections.PrimaryScheduler?.Name!,
Positive =
e.Builder.Connections.BaseConditioning
?? throw new ArgumentException("No BaseConditioning"),
Negative =
e.Builder.Connections.BaseNegativeConditioning
?? throw new ArgumentException("No BaseNegativeConditioning"),
LatentImage = primaryLatent,
StartAtStep = 0,
EndAtStep = TotalSteps,
ReturnWithLeftoverNoise = true
}
);
// Add refiner sampler
var refinerSampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSamplerAdvanced
{
Name = "Refiner_Sampler",
Model =
e.Builder.Connections.RefinerModel
?? throw new ArgumentException("No RefinerModel"),
AddNoise = false,
NoiseSeed = e.Builder.Connections.Seed,
Steps = TotalSteps,
Cfg = CfgScale,
Sampler = e.Builder.Connections.PrimarySampler?.Name!,
Scheduler = e.Builder.Connections.PrimaryScheduler?.Name!,
Positive =
e.Builder.Connections.RefinerConditioning
?? throw new ArgumentException("No RefinerConditioning"),
Negative =
e.Builder.Connections.RefinerNegativeConditioning
?? throw new ArgumentException("No RefinerNegativeConditioning"),
// Connect to previous sampler
LatentImage = sampler.Output,
StartAtStep = Steps,
EndAtStep = TotalSteps,
ReturnWithLeftoverNoise = false
}
);
e.Builder.Connections.Primary = refinerSampler.Output;
}
}
private void ApplyStepsAdditionalSampler(ModuleApplyStepEventArgs e) { }
/// <summary>
/// Applies each step of our addons
/// </summary>
/// <param name="e"></param>
private void ApplyAddonSteps(ModuleApplyStepEventArgs e)
{
// Apply steps from our modules
foreach (var module in ModulesCardViewModel.Cards.Cast<ModuleBase>())

8
StabilityMatrix.Core/Attributes/BoolStringMemberAttribute.cs

@ -0,0 +1,8 @@
namespace StabilityMatrix.Core.Attributes;
[AttributeUsage(AttributeTargets.Property)]
public class BoolStringMemberAttribute(string trueString, string falseString) : Attribute
{
public string TrueString { get; } = trueString;
public string FalseString { get; } = falseString;
}

100
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -1,5 +1,8 @@
using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using System.Runtime.Serialization;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database;
@ -68,7 +71,21 @@ public class ComfyNodeBuilder
};
}
public static NamedComfyNode<LatentNodeConnection> KSampler(
public record KSampler : ComfyTypedNodeBase<LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
public required ulong Seed { get; init; }
public required int Steps { get; init; }
public required double Cfg { get; init; }
public required string SamplerName { get; init; }
public required string Scheduler { get; init; }
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required LatentNodeConnection LatentImage { get; init; }
public required double Denoise { get; init; }
}
/*public static NamedComfyNode<LatentNodeConnection> KSampler(
string name,
ModelNodeConnection model,
ulong seed,
@ -99,9 +116,30 @@ public class ComfyNodeBuilder
["denoise"] = denoise
}
};
}*/
public record KSamplerAdvanced : ComfyTypedNodeBase<LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
[BoolStringMember("enable", "disable")]
public required bool AddNoise { get; init; }
public required ulong NoiseSeed { get; init; }
public required int Steps { get; init; }
public required double Cfg { get; init; }
public required string Sampler { get; init; }
public required string Scheduler { get; init; }
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required LatentNodeConnection LatentImage { get; init; }
public required int StartAtStep { get; init; }
public required int EndAtStep { get; init; }
[BoolStringMember("enable", "disable")]
public bool ReturnWithLeftoverNoise { get; init; }
}
public static NamedComfyNode<LatentNodeConnection> KSamplerAdvanced(
/*public static NamedComfyNode<LatentNodeConnection> KSamplerAdvanced(
string name,
ModelNodeConnection model,
bool addNoise,
@ -138,25 +176,13 @@ public class ComfyNodeBuilder
["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable"
}
};
}
}*/
public static NamedComfyNode<LatentNodeConnection> EmptyLatentImage(
string name,
int batchSize,
int height,
int width
)
public record EmptyLatentImage : ComfyTypedNodeBase<LatentNodeConnection>
{
return new NamedComfyNode<LatentNodeConnection>(name)
{
ClassType = "EmptyLatentImage",
Inputs = new Dictionary<string, object?>
{
["batch_size"] = batchSize,
["height"] = height,
["width"] = width,
}
};
public required int BatchSize { get; init; }
public required int Height { get; init; }
public required int Width { get; init; }
}
public static NamedComfyNode<LatentNodeConnection> LatentFromBatch(
@ -264,27 +290,20 @@ public class ComfyNodeBuilder
public required string CkptName { get; init; }
}
public static NamedComfyNode<ModelNodeConnection> FreeU(
string name,
ModelNodeConnection model,
double b1,
double b2,
double s1,
double s2
)
public record FreeU : ComfyTypedNodeBase<ModelNodeConnection>
{
return new NamedComfyNode<ModelNodeConnection>(name)
{
ClassType = "FreeU",
Inputs = new Dictionary<string, object?>
{
["model"] = model.Data,
["b1"] = b1,
["b2"] = b2,
["s1"] = s1,
["s2"] = s2
}
};
public required ModelNodeConnection Model { get; init; }
public required double B1 { get; init; }
public required double B2 { get; init; }
public required double S1 { get; init; }
public required double S2 { get; init; }
}
[SuppressMessage("ReSharper", "InconsistentNaming")]
public record CLIPTextEncode : ComfyTypedNodeBase<ConditioningNodeConnection>
{
public required ClipNodeConnection Clip { get; init; }
public required string Text { get; init; }
}
public static NamedComfyNode<ConditioningNodeConnection> ClipTextEncode(
@ -818,6 +837,9 @@ public class ComfyNodeBuilder
public VAENodeConnection? PrimaryVAE { get; set; }
public Size PrimarySize { get; set; }
public ComfySampler? PrimarySampler { get; set; }
public ComfyScheduler? PrimaryScheduler { get; set; }
public List<NamedComfyNode> OutputNodes { get; } = new();
public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name);

18
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs

@ -1,5 +1,6 @@
using System.Reflection;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using Yoh.Text.Json.NamingPolicies;
@ -32,6 +33,23 @@ public abstract record ComfyTypedNodeBase
property.GetCustomAttribute<JsonPropertyNameAttribute>()?.Name
?? namingPolicy.ConvertName(property.Name);
// If theres a BoolStringMember attribute, convert to one of the strings
if (property.GetCustomAttribute<BoolStringMemberAttribute>() is { } converter)
{
if (value is bool boolValue)
{
inputs.Add(key, boolValue ? converter.TrueString : converter.FalseString);
}
else
{
throw new InvalidOperationException(
$"Property {property.Name} is not a bool, but has a BoolStringMember attribute"
);
}
continue;
}
// For connection types, use data property
if (value is NodeConnectionBase connection)
{

Loading…
Cancel
Save