Browse Source

Move inference logic to ApplyStep

pull/333/head
Ionite 12 months ago
parent
commit
5bc1715a49
No known key found for this signature in database
  1. 307
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 162
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  3. 14
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  4. 52
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs
  5. 48
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs
  6. 118
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  7. 138
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
  8. 8
      StabilityMatrix.Core/Attributes/BoolStringMemberAttribute.cs
  9. 100
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  10. 18
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs

307
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -1,40 +1,38 @@
using System; using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.Drawing; using System.Drawing;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.Extensions; namespace StabilityMatrix.Avalonia.Extensions;
public static class ComfyNodeBuilderExtensions public static class ComfyNodeBuilderExtensions
{ {
public static void SetupLatentSource( public static void SetupEmptyLatentSource(
this ComfyNodeBuilder builder, this ComfyNodeBuilder builder,
BatchSizeCardViewModel batchSizeCardViewModel, int width,
SamplerCardViewModel samplerCardViewModel int height,
int batchSize = 1,
int? batchIndex = null
) )
{ {
var emptyLatent = builder.Nodes.AddNamedNode( var emptyLatent = builder.Nodes.AddTypedNode(
ComfyNodeBuilder.EmptyLatentImage( new ComfyNodeBuilder.EmptyLatentImage
"EmptyLatentImage", {
batchSizeCardViewModel.BatchSize, Name = "EmptyLatentImage",
samplerCardViewModel.Height, BatchSize = batchSize,
samplerCardViewModel.Width Height = height,
) Width = width
}
); );
builder.Connections.Primary = emptyLatent.Output; builder.Connections.Primary = emptyLatent.Output;
builder.Connections.PrimarySize = new Size( builder.Connections.PrimarySize = new Size(width, height);
samplerCardViewModel.Width,
samplerCardViewModel.Height
);
// If batch index is selected, add a LatentFromBatch // If batch index is selected, add a LatentFromBatch
if (batchSizeCardViewModel.IsBatchIndexEnabled) if (batchIndex is not null)
{ {
builder.Connections.Primary = builder.Nodes builder.Connections.Primary = builder.Nodes
.AddNamedNode( .AddNamedNode(
@ -42,7 +40,7 @@ public static class ComfyNodeBuilderExtensions
"LatentFromBatch", "LatentFromBatch",
builder.GetPrimaryAsLatent(), builder.GetPrimaryAsLatent(),
// remote expects a 0-based index, vm is 1-based // remote expects a 0-based index, vm is 1-based
batchSizeCardViewModel.BatchIndex - 1, batchIndex.Value - 1,
1 1
) )
) )
@ -50,256 +48,63 @@ public static class ComfyNodeBuilderExtensions
} }
} }
public static void SetupBaseSampler( public static void SetupImageLatentSource(
this ComfyNodeBuilder builder, this ComfyNodeBuilder builder,
SamplerCardViewModel samplerCardViewModel, BatchSizeCardViewModel batchSizeCardViewModel,
PromptCardViewModel promptCardViewModel, SamplerCardViewModel samplerCardViewModel
ModelCardViewModel modelCardViewModel,
IModelIndexService modelIndexService,
Action<ComfyNodeBuilder>? postModelLoad = null
) )
{ {
/*// Load base checkpoint var emptyLatent = builder.Nodes.AddTypedNode(
var checkpointLoader = builder.Nodes.AddNamedNode( new ComfyNodeBuilder.EmptyLatentImage
ComfyNodeBuilder.CheckpointLoaderSimple(
"CheckpointLoader",
modelCardViewModel.SelectedModel?.FileName
?? throw new NullReferenceException("Model not selected")
)
);
builder.Connections.BaseModel = checkpointLoader.GetOutput<ModelNodeConnection>(0);
builder.Connections.BaseClip = checkpointLoader.GetOutput<ClipNodeConnection>(1);
builder.Connections.BaseVAE = checkpointLoader.GetOutput<VAENodeConnection>(2);
builder.Connections.PrimaryVAE = builder.Connections.BaseVAE;
// Run post model load action
postModelLoad?.Invoke(builder);*/
// Load prompts
var prompt = promptCardViewModel.GetPrompt();
prompt.Process();
var negativePrompt = promptCardViewModel.GetNegativePrompt();
negativePrompt.Process();
// If need to load loras, add a group
if (prompt.ExtraNetworks.Count > 0)
{
// Convert to local file names
var lorasGroup = builder.Group_LoraLoadMany(
"Loras",
builder.Connections.BaseModel,
builder.Connections.BaseClip,
prompt.GetExtraNetworksAsLocalModels(modelIndexService)
);
// Set as source
builder.Connections.BaseModel = lorasGroup.Output1;
builder.Connections.BaseClip = lorasGroup.Output2;
}
// Clips
var positiveClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"PositiveCLIP",
builder.Connections.BaseClip,
prompt.ProcessedText
)
);
var negativeClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"NegativeCLIP",
builder.Connections.BaseClip,
negativePrompt.ProcessedText
)
);
builder.Connections.BaseConditioning = positiveClip.Output;
builder.Connections.BaseNegativeConditioning = negativeClip.Output;
// Apply sampler addons (FreeU / ControlNet) to model and conditioning
var samplerStepArgs = new ModuleApplyStepEventArgs
{
Builder = builder,
Temp =
{ {
Model = builder.Connections.BaseModel, Name = "EmptyLatentImage",
Conditioning = (positiveClip.Output, negativeClip.Output) BatchSize = batchSizeCardViewModel.BatchSize,
Height = samplerCardViewModel.Height,
Width = samplerCardViewModel.Width
} }
};
samplerCardViewModel.ApplyStep(samplerStepArgs);
var model = samplerStepArgs.Temp.Model;
var conditioning = samplerStepArgs.Temp.Conditioning;
// Primary latent encoding vae
var vae =
builder.Connections.PrimaryVAE
?? builder.Connections.BaseVAE
?? throw new ValidationException("No Primary or Base VAE");
var latent = builder.GetPrimaryAsLatent(vae);
// Add base sampler (without refiner)
if (
modelCardViewModel
is not { IsRefinerSelectionEnabled: true, SelectedRefiner.IsDefault: false }
)
{
var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSampler(
"Sampler",
model,
builder.Connections.Seed,
samplerCardViewModel.Steps,
samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Scheduler not selected"),
conditioning.Positive,
conditioning.Negative,
latent,
samplerCardViewModel.DenoiseStrength
)
);
builder.Connections.Primary = sampler.Output;
}
// Add base sampler (with refiner)
else
{
// Total steps is the sum of the base and refiner steps
var totalSteps = samplerCardViewModel.Steps + samplerCardViewModel.RefinerSteps;
var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSamplerAdvanced(
"Sampler",
model,
true,
builder.Connections.Seed,
totalSteps,
samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Sampler not selected"),
conditioning.Positive,
conditioning.Negative,
latent,
0,
samplerCardViewModel.Steps,
true
)
);
builder.Connections.Primary = sampler.Output;
}
}
public static void SetupRefinerSampler(
this ComfyNodeBuilder builder,
SamplerCardViewModel samplerCardViewModel,
PromptCardViewModel promptCardViewModel,
ModelCardViewModel modelCardViewModel,
IModelIndexService modelIndexService,
Action<ComfyNodeBuilder>? postModelLoad = null
)
{
/*// Load refiner checkpoint
var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple(
"Refiner_CheckpointLoader",
modelCardViewModel.SelectedRefiner?.RelativePath
?? throw new NullReferenceException("Model not selected")
)
); );
builder.Connections.RefinerModel = checkpointLoader.GetOutput<ModelNodeConnection>(0); builder.Connections.Primary = emptyLatent.Output;
builder.Connections.RefinerClip = checkpointLoader.GetOutput<ClipNodeConnection>(1); builder.Connections.PrimarySize = new Size(
builder.Connections.RefinerVAE = checkpointLoader.GetOutput<VAENodeConnection>(2); samplerCardViewModel.Width,
builder.Connections.PrimaryVAE = builder.Connections.RefinerVAE; samplerCardViewModel.Height
);
// Run post model load action
postModelLoad?.Invoke(builder);*/
// Load prompts
var prompt = promptCardViewModel.GetPrompt();
prompt.Process();
var negativePrompt = promptCardViewModel.GetNegativePrompt();
negativePrompt.Process();
// If need to load loras, add a group // If batch index is selected, add a LatentFromBatch
if (prompt.ExtraNetworks.Count > 0) if (batchSizeCardViewModel.IsBatchIndexEnabled)
{ {
// Convert to local file names builder.Connections.Primary = builder.Nodes
var lorasGroup = builder.Group_LoraLoadMany( .AddNamedNode(
"Refiner_Loras", ComfyNodeBuilder.LatentFromBatch(
builder.Connections.RefinerModel, "LatentFromBatch",
builder.Connections.RefinerClip, builder.GetPrimaryAsLatent(),
prompt.GetExtraNetworksAsLocalModels(modelIndexService) // remote expects a 0-based index, vm is 1-based
); batchSizeCardViewModel.BatchIndex - 1,
1
// Set as source )
builder.Connections.RefinerModel = lorasGroup.Output1; )
builder.Connections.RefinerClip = lorasGroup.Output2; .Output;
} }
// Clips
var positiveClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"Refiner_PositiveCLIP",
builder.Connections.RefinerClip,
prompt.ProcessedText
)
);
var negativeClip = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.ClipTextEncode(
"Refiner_NegativeCLIP",
builder.Connections.RefinerClip,
negativePrompt.ProcessedText
)
);
builder.Connections.RefinerConditioning = positiveClip.Output;
builder.Connections.RefinerNegativeConditioning = negativeClip.Output;
// Add refiner sampler
// Total steps is the sum of the base and refiner steps
var totalSteps = samplerCardViewModel.Steps + samplerCardViewModel.RefinerSteps;
var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSamplerAdvanced(
"Refiner_Sampler",
builder.Connections.RefinerModel,
false,
builder.Connections.Seed,
totalSteps,
samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Sampler not selected"),
positiveClip.Output,
negativeClip.Output,
builder.GetPrimaryAsLatent(),
samplerCardViewModel.Steps,
totalSteps,
false
)
);
builder.Connections.Primary = sampler.Output;
} }
public static string SetupOutputImage(this ComfyNodeBuilder builder) public static string SetupOutputImage(this ComfyNodeBuilder builder)
{ {
var previewImage = builder.Nodes.AddTypedNode( if (builder.Connections.Primary is null)
new ComfyNodeBuilder.PreviewImage throw new ArgumentException("No Primary");
{
Name = "SaveImage", var image = builder.Connections.Primary.Match(
Images = builder.GetPrimaryAsImage( _ =>
builder.GetPrimaryAsImage(
builder.Connections.PrimaryVAE builder.Connections.PrimaryVAE
?? builder.Connections.RefinerVAE ?? builder.Connections.RefinerVAE
?? builder.Connections.BaseVAE ?? builder.Connections.BaseVAE
) ?? throw new ArgumentException("No Primary, Refiner, or Base VAE")
} ),
image => image
);
var previewImage = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.PreviewImage { Name = "SaveImage", Images = image }
); );
builder.Connections.OutputNodes.Add(previewImage); builder.Connections.OutputNodes.Add(previewImage);

162
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -148,37 +148,11 @@ public class InferenceTextToImageViewModel
StackCardViewModel = vmFactory.Get<StackCardViewModel>(); StackCardViewModel = vmFactory.Get<StackCardViewModel>();
StackCardViewModel.AddCards( StackCardViewModel.AddCards(
new LoadableViewModelBase[] ModelCardViewModel,
{ SamplerCardViewModel,
ModelCardViewModel, ModulesCardViewModel,
SamplerCardViewModel, SeedCardViewModel,
ModulesCardViewModel, BatchSizeCardViewModel
/*// Free U
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "FreeU";
stackExpander.AddCards(new LoadableViewModelBase[] { FreeUCardViewModel });
}),
// Hires Fix
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "Hires Fix";
stackExpander.AddCards(
new LoadableViewModelBase[]
{
HiresUpscalerCardViewModel,
HiresSamplerCardViewModel
}
);
}),
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "Upscale";
stackExpander.AddCards(new LoadableViewModelBase[] { UpscalerCardViewModel });
}),*/
SeedCardViewModel,
BatchSizeCardViewModel,
}
); );
// When refiner is provided in model card, enable for sampler // When refiner is provided in model card, enable for sampler
@ -196,128 +170,36 @@ public class InferenceTextToImageViewModel
{ {
base.BuildPrompt(args); base.BuildPrompt(args);
using var _ = CodeTimer.StartDebug();
var builder = args.Builder; var builder = args.Builder;
var nodes = builder.Nodes;
if (args.SeedOverride is { } seed) builder.Connections.Seed = args.SeedOverride switch
{
builder.Connections.Seed = Convert.ToUInt64(seed);
}
else
{ {
builder.Connections.Seed = Convert.ToUInt64(SeedCardViewModel.Seed); { } seed => Convert.ToUInt64(seed),
} _ => Convert.ToUInt64(SeedCardViewModel.Seed)
};
// Load models // Load models
ModelCardViewModel.ApplyStep(args); ModelCardViewModel.ApplyStep(args);
// Setup empty latent // Setup empty latent
builder.SetupLatentSource(BatchSizeCardViewModel, SamplerCardViewModel); builder.SetupEmptyLatentSource(
SamplerCardViewModel.Width,
// Setup base sampling stage SamplerCardViewModel.Height,
builder.SetupBaseSampler( BatchSizeCardViewModel.BatchSize,
SamplerCardViewModel, BatchSizeCardViewModel.IsBatchIndexEnabled ? BatchSizeCardViewModel.BatchIndex : null
PromptCardViewModel,
ModelCardViewModel,
modelIndexService
); );
// Setup refiner stage if enabled // Prompts and loras
if ( PromptCardViewModel.ApplyStep(args);
ModelCardViewModel is
{ IsRefinerSelectionEnabled: true, SelectedRefiner.IsDefault: false }
)
{
builder.SetupRefinerSampler(
SamplerCardViewModel,
PromptCardViewModel,
ModelCardViewModel,
modelIndexService
);
}
// Override with custom VAE if enabled // Setup Sampler and Refiner if enabled
/*if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false }) SamplerCardViewModel.ApplyStep(args);
{
var customVaeLoader = nodes.AddNamedNode(
ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath)
);
builder.Connections.PrimaryVAE = customVaeLoader.Output; // Hires fix if enabled
}*/ foreach (var module in ModulesCardViewModel.Cards.OfType<ModuleBase>())
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
/*if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
{ {
// Get new latent size module.ApplyStep(args);
var hiresSize = builder.Connections.PrimarySize.WithScale( }
HiresUpscalerCardViewModel.Scale
);
// Select between latent upscale and normal upscale based on the upscale method
var selectedUpscaler = HiresUpscalerCardViewModel.SelectedUpscaler!.Value;
// If upscaler selected, upscale latent image first
if (selectedUpscaler.Type != ComfyUpscalerType.None)
{
builder.Connections.Primary = builder.Group_Upscale(
"HiresFix",
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
selectedUpscaler,
hiresSize.Width,
hiresSize.Height
);
}
// Use refiner model if set, or base if not
var hiresSampler = nodes.AddNamedNode(
ComfyNodeBuilder.KSampler(
"HiresSampler",
builder.Connections.GetRefinerOrBaseModel(),
builder.Connections.Seed,
HiresSamplerCardViewModel.Steps,
HiresSamplerCardViewModel.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler
HiresSamplerCardViewModel.SelectedSampler
?? SamplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
HiresSamplerCardViewModel.SelectedScheduler
?? SamplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Scheduler not selected"),
builder.Connections.GetRefinerOrBaseConditioning(),
builder.Connections.GetRefinerOrBaseNegativeConditioning(),
builder.GetPrimaryAsLatent(),
HiresSamplerCardViewModel.DenoiseStrength
)
);
// Set as primary
builder.Connections.Primary = hiresSampler.Output;
builder.Connections.PrimarySize = hiresSize;
}*/
// If upscale is enabled, add another upscale group
/*if (IsUpscaleEnabled)
{
var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
var upscaleResult = builder.Group_Upscale(
"PostUpscale",
builder.Connections.Primary!,
builder.Connections.PrimaryVAE!,
UpscalerCardViewModel.SelectedUpscaler!.Value,
upscaleSize.Width,
upscaleSize.Height
);
builder.Connections.Primary = upscaleResult;
builder.Connections.PrimarySize = upscaleSize;
}*/
builder.SetupOutputImage(); builder.SetupOutputImage();
} }

14
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -129,13 +129,6 @@ public partial class ModelCardViewModel
: ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName); : ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName);
} }
internal class ModelCardModel
{
public string? SelectedModelName { get; init; }
public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; }
}
/// <inheritdoc /> /// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters) public void LoadStateFromParameters(GenerationParameters parameters)
{ {
@ -182,4 +175,11 @@ public partial class ModelCardViewModel
ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256 ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
}; };
} }
internal class ModelCardModel
{
public string? SelectedModelName { get; init; }
public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; }
}
} }

52
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs

@ -1,5 +1,4 @@
using System; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
@ -26,20 +25,41 @@ public class FreeUModule : ModuleBase
{ {
var card = GetCard<FreeUCardViewModel>(); var card = GetCard<FreeUCardViewModel>();
e.Temp.Model = e.Nodes // Currently applies to both base and refiner model
.AddNamedNode( // TODO: Add option to apply to either base or refiner
ComfyNodeBuilder.FreeU(
e.Nodes.GetUniqueName("FreeU"), if (e.Builder.Connections.BaseModel is not null)
e.Temp.Model {
?? throw new ArgumentException( e.Builder.Connections.BaseModel = e.Nodes
"Temp.Model not set on ModuleApplyStepEventArgs" .AddTypedNode(
), new ComfyNodeBuilder.FreeU
card.B1, {
card.B2, Name = e.Nodes.GetUniqueName("FreeU"),
card.S1, Model = e.Builder.Connections.BaseModel,
card.S2 B1 = card.B1,
B2 = card.B2,
S1 = card.S1,
S2 = card.S2
}
)
.Output;
}
if (e.Builder.Connections.RefinerModel is not null)
{
e.Builder.Connections.RefinerModel = e.Nodes
.AddTypedNode(
new ComfyNodeBuilder.FreeU
{
Name = e.Nodes.GetUniqueName("Refiner_FreeU"),
Model = e.Builder.Connections.RefinerModel,
B1 = card.B1,
B2 = card.B2,
S1 = card.S1,
S2 = card.S2
}
) )
) .Output;
.Output; }
} }
} }

48
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs

@ -1,4 +1,5 @@
using System.ComponentModel.DataAnnotations; using System;
using System.ComponentModel.DataAnnotations;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
@ -46,34 +47,35 @@ public class HiresFixModule : ModuleBase
{ {
builder.Connections.Primary = builder.Group_Upscale( builder.Connections.Primary = builder.Group_Upscale(
"HiresFix", "HiresFix",
builder.Connections.Primary!, builder.Connections.Primary ?? throw new ArgumentException("No Primary"),
builder.Connections.PrimaryVAE!, builder.Connections.PrimaryVAE ?? throw new ArgumentException("No PrimaryVAE"),
selectedUpscaler, selectedUpscaler,
hiresSize.Width, hiresSize.Width,
hiresSize.Height hiresSize.Height
); );
} }
// Use refiner model if set, or base if not var hiresSampler = builder.Nodes.AddTypedNode(
var hiresSampler = builder.Nodes.AddNamedNode( new ComfyNodeBuilder.KSampler
ComfyNodeBuilder.KSampler( {
builder.Nodes.GetUniqueName("HiresFix_Sampler"), Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"),
builder.Connections.GetRefinerOrBaseModel(), Model = builder.Connections.GetRefinerOrBaseModel(),
builder.Connections.Seed, Seed = builder.Connections.Seed,
samplerCard.Steps, Steps = samplerCard.Steps,
samplerCard.CfgScale, Cfg = samplerCard.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler SamplerName =
samplerCard.SelectedSampler samplerCard.SelectedSampler?.Name
?? samplerCard.SelectedSampler ?? e.Builder.Connections.PrimarySampler?.Name
?? throw new ValidationException("Sampler not selected"), ?? throw new ArgumentException("No PrimarySampler"),
samplerCard.SelectedScheduler Scheduler =
?? samplerCard.SelectedScheduler samplerCard.SelectedScheduler?.Name
?? throw new ValidationException("Scheduler not selected"), ?? e.Builder.Connections.PrimaryScheduler?.Name
builder.Connections.GetRefinerOrBaseConditioning(), ?? throw new ArgumentException("No PrimaryScheduler"),
builder.Connections.GetRefinerOrBaseNegativeConditioning(), Positive = builder.Connections.GetRefinerOrBaseConditioning(),
builder.GetPrimaryAsLatent(), Negative = builder.Connections.GetRefinerOrBaseNegativeConditioning(),
samplerCard.DenoiseStrength LatentImage = builder.GetPrimaryAsLatent(),
) Denoise = samplerCard.DenoiseStrength
}
); );
// Set as primary // Set as primary

118
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,4 +1,6 @@
using System.Text; using System;
using System.Linq;
using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -16,6 +18,7 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -23,7 +26,10 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))] [View(typeof(PromptCard))]
[ManagedService] [ManagedService]
[Transient] [Transient]
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState public partial class PromptCardViewModel
: LoadableViewModelBase,
IParametersLoadableState,
IComfyStep
{ {
private readonly IModelIndexService modelIndexService; private readonly IModelIndexService modelIndexService;
@ -64,6 +70,114 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
); );
} }
/// <summary>
/// Applies the prompt step.
/// Requires:
/// <list type="number">
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseModel"/></item>
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseClip"/></item>
/// </list>
/// Provides:
/// <list type="number">
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseConditioning"/></item>
/// <item><see cref="ComfyNodeBuilder.NodeBuilderConnections.BaseNegativeConditioning"/></item>
/// </list>
/// </summary>
public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Load prompts
var positivePrompt = GetPrompt();
positivePrompt.Process();
var negativePrompt = GetNegativePrompt();
negativePrompt.Process();
// If need to load loras, add a group
if (positivePrompt.ExtraNetworks.Count > 0)
{
var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList();
// Add group to load loras onto model and clip in series
var lorasGroup = e.Builder.Group_LoraLoadMany(
"Loras",
e.Builder.Connections.BaseModel ?? throw new ArgumentException("BaseModel is null"),
e.Builder.Connections.BaseClip ?? throw new ArgumentException("BaseClip is null"),
loras
);
// Set last outputs as base model and clip
e.Builder.Connections.BaseModel = lorasGroup.Output1;
e.Builder.Connections.BaseClip = lorasGroup.Output2;
// Refiner loras
if (e.Builder.Connections.RefinerModel is not null)
{
// Add group to load loras onto refiner model and clip in series
var lorasGroupRefiner = e.Builder.Group_LoraLoadMany(
"Refiner_Loras",
e.Builder.Connections.RefinerModel
?? throw new ArgumentException("RefinerModel is null"),
e.Builder.Connections.RefinerClip
?? throw new ArgumentException("RefinerClip is null"),
loras
);
// Set last outputs as refiner model and clip
e.Builder.Connections.RefinerModel = lorasGroupRefiner.Output1;
e.Builder.Connections.RefinerClip = lorasGroupRefiner.Output2;
}
}
// Clips
var positiveClip = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "PositiveCLIP",
Clip = e.Builder.Connections.BaseClip!,
Text = positivePrompt.ProcessedText
}
);
var negativeClip = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "NegativeCLIP",
Clip = e.Builder.Connections.BaseClip!,
Text = negativePrompt.ProcessedText
}
);
// Set base conditioning from Clips
e.Builder.Connections.BaseConditioning = positiveClip.Output;
e.Builder.Connections.BaseNegativeConditioning = negativeClip.Output;
// Refiner Clips
if (e.Builder.Connections.RefinerModel is not null)
{
var positiveClipRefiner = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "Refiner_PositiveCLIP",
Clip =
e.Builder.Connections.RefinerClip
?? throw new ArgumentException("RefinerClip is null"),
Text = positivePrompt.ProcessedText
}
);
var negativeClipRefiner = e.Builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.CLIPTextEncode
{
Name = "Refiner_NegativeCLIP",
Clip =
e.Builder.Connections.RefinerClip
?? throw new ArgumentException("RefinerClip is null"),
Text = negativePrompt.ProcessedText
}
);
// Set refiner conditioning from Clips
e.Builder.Connections.RefinerConditioning = positiveClipRefiner.Output;
e.Builder.Connections.RefinerNegativeConditioning = negativeClipRefiner.Output;
}
}
/// <summary> /// <summary>
/// Gets the tokenized Prompt for given text and caches it /// Gets the tokenized Prompt for given text and caches it
/// </summary> /// </summary>

138
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -1,4 +1,6 @@
using System.ComponentModel.DataAnnotations; using System;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Linq; using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
@ -12,6 +14,7 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -75,6 +78,8 @@ public partial class SamplerCardViewModel
[JsonIgnore] [JsonIgnore]
public IInferenceClientManager ClientManager { get; } public IInferenceClientManager ClientManager { get; }
private int TotalSteps => Steps + RefinerSteps;
public SamplerCardViewModel( public SamplerCardViewModel(
IInferenceClientManager clientManager, IInferenceClientManager clientManager,
ServiceManager<ViewModelBase> vmFactory ServiceManager<ViewModelBase> vmFactory
@ -102,6 +107,137 @@ public partial class SamplerCardViewModel
/// <inheritdoc /> /// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e) public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Apply steps from our addons
ApplyAddonSteps(e);
// If "Sampler" is not yet a node, do initial setup
// otherwise do hires setup
if (!e.Nodes.ContainsKey("Sampler"))
{
ApplyStepsInitialSampler(e);
}
else
{
ApplyStepsAdditionalSampler(e);
}
}
private void ApplyStepsInitialSampler(ModuleApplyStepEventArgs e)
{
// Get primary or base VAE
var vae =
e.Builder.Connections.PrimaryVAE
?? e.Builder.Connections.BaseVAE
?? throw new ArgumentException("No Primary or Base VAE");
// Get primary as latent using vae
var primaryLatent = e.Builder.GetPrimaryAsLatent(vae);
// Set primary sampler and scheduler
e.Builder.Connections.PrimarySampler =
SelectedSampler ?? throw new ValidationException("Sampler not selected");
e.Builder.Connections.PrimaryScheduler =
SelectedScheduler ?? throw new ValidationException("Scheduler not selected");
// Use KSampler if no refiner, otherwise need KSamplerAdvanced
if (e.Builder.Connections.RefinerModel is null)
{
// No refiner
var sampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSampler
{
Name = "Sampler",
Model =
e.Builder.Connections.BaseModel
?? throw new ArgumentException("No BaseModel"),
Seed = e.Builder.Connections.Seed,
SamplerName = e.Builder.Connections.PrimarySampler?.Name!,
Scheduler = e.Builder.Connections.PrimaryScheduler?.Name!,
Steps = Steps,
Cfg = CfgScale,
Positive =
e.Builder.Connections.BaseConditioning
?? throw new ArgumentException("No BaseConditioning"),
Negative =
e.Builder.Connections.BaseNegativeConditioning
?? throw new ArgumentException("No BaseNegativeConditioning"),
LatentImage = primaryLatent,
Denoise = DenoiseStrength,
}
);
e.Builder.Connections.Primary = sampler.Output;
}
else
{
// Advanced base sampler for refiner
var sampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSamplerAdvanced
{
Name = "Sampler",
Model =
e.Builder.Connections.BaseModel
?? throw new ArgumentException("No BaseModel"),
AddNoise = true,
NoiseSeed = e.Builder.Connections.Seed,
Steps = TotalSteps,
Cfg = CfgScale,
Sampler = e.Builder.Connections.PrimarySampler?.Name!,
Scheduler = e.Builder.Connections.PrimaryScheduler?.Name!,
Positive =
e.Builder.Connections.BaseConditioning
?? throw new ArgumentException("No BaseConditioning"),
Negative =
e.Builder.Connections.BaseNegativeConditioning
?? throw new ArgumentException("No BaseNegativeConditioning"),
LatentImage = primaryLatent,
StartAtStep = 0,
EndAtStep = TotalSteps,
ReturnWithLeftoverNoise = true
}
);
// Add refiner sampler
var refinerSampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSamplerAdvanced
{
Name = "Refiner_Sampler",
Model =
e.Builder.Connections.RefinerModel
?? throw new ArgumentException("No RefinerModel"),
AddNoise = false,
NoiseSeed = e.Builder.Connections.Seed,
Steps = TotalSteps,
Cfg = CfgScale,
Sampler = e.Builder.Connections.PrimarySampler?.Name!,
Scheduler = e.Builder.Connections.PrimaryScheduler?.Name!,
Positive =
e.Builder.Connections.RefinerConditioning
?? throw new ArgumentException("No RefinerConditioning"),
Negative =
e.Builder.Connections.RefinerNegativeConditioning
?? throw new ArgumentException("No RefinerNegativeConditioning"),
// Connect to previous sampler
LatentImage = sampler.Output,
StartAtStep = Steps,
EndAtStep = TotalSteps,
ReturnWithLeftoverNoise = false
}
);
e.Builder.Connections.Primary = refinerSampler.Output;
}
}
private void ApplyStepsAdditionalSampler(ModuleApplyStepEventArgs e) { }
/// <summary>
/// Applies each step of our addons
/// </summary>
/// <param name="e"></param>
private void ApplyAddonSteps(ModuleApplyStepEventArgs e)
{ {
// Apply steps from our modules // Apply steps from our modules
foreach (var module in ModulesCardViewModel.Cards.Cast<ModuleBase>()) foreach (var module in ModulesCardViewModel.Cards.Cast<ModuleBase>())

8
StabilityMatrix.Core/Attributes/BoolStringMemberAttribute.cs

@ -0,0 +1,8 @@
namespace StabilityMatrix.Core.Attributes;
[AttributeUsage(AttributeTargets.Property)]
public class BoolStringMemberAttribute(string trueString, string falseString) : Attribute
{
public string TrueString { get; } = trueString;
public string FalseString { get; } = falseString;
}

100
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -1,5 +1,8 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Drawing; using System.Drawing;
using System.Runtime.Serialization;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.Database;
@ -68,7 +71,21 @@ public class ComfyNodeBuilder
}; };
} }
public static NamedComfyNode<LatentNodeConnection> KSampler( public record KSampler : ComfyTypedNodeBase<LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
public required ulong Seed { get; init; }
public required int Steps { get; init; }
public required double Cfg { get; init; }
public required string SamplerName { get; init; }
public required string Scheduler { get; init; }
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required LatentNodeConnection LatentImage { get; init; }
public required double Denoise { get; init; }
}
/*public static NamedComfyNode<LatentNodeConnection> KSampler(
string name, string name,
ModelNodeConnection model, ModelNodeConnection model,
ulong seed, ulong seed,
@ -99,9 +116,30 @@ public class ComfyNodeBuilder
["denoise"] = denoise ["denoise"] = denoise
} }
}; };
}*/
public record KSamplerAdvanced : ComfyTypedNodeBase<LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
[BoolStringMember("enable", "disable")]
public required bool AddNoise { get; init; }
public required ulong NoiseSeed { get; init; }
public required int Steps { get; init; }
public required double Cfg { get; init; }
public required string Sampler { get; init; }
public required string Scheduler { get; init; }
public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; }
public required LatentNodeConnection LatentImage { get; init; }
public required int StartAtStep { get; init; }
public required int EndAtStep { get; init; }
[BoolStringMember("enable", "disable")]
public bool ReturnWithLeftoverNoise { get; init; }
} }
public static NamedComfyNode<LatentNodeConnection> KSamplerAdvanced( /*public static NamedComfyNode<LatentNodeConnection> KSamplerAdvanced(
string name, string name,
ModelNodeConnection model, ModelNodeConnection model,
bool addNoise, bool addNoise,
@ -138,25 +176,13 @@ public class ComfyNodeBuilder
["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable" ["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable"
} }
}; };
} }*/
public static NamedComfyNode<LatentNodeConnection> EmptyLatentImage( public record EmptyLatentImage : ComfyTypedNodeBase<LatentNodeConnection>
string name,
int batchSize,
int height,
int width
)
{ {
return new NamedComfyNode<LatentNodeConnection>(name) public required int BatchSize { get; init; }
{ public required int Height { get; init; }
ClassType = "EmptyLatentImage", public required int Width { get; init; }
Inputs = new Dictionary<string, object?>
{
["batch_size"] = batchSize,
["height"] = height,
["width"] = width,
}
};
} }
public static NamedComfyNode<LatentNodeConnection> LatentFromBatch( public static NamedComfyNode<LatentNodeConnection> LatentFromBatch(
@ -264,27 +290,20 @@ public class ComfyNodeBuilder
public required string CkptName { get; init; } public required string CkptName { get; init; }
} }
public static NamedComfyNode<ModelNodeConnection> FreeU( public record FreeU : ComfyTypedNodeBase<ModelNodeConnection>
string name,
ModelNodeConnection model,
double b1,
double b2,
double s1,
double s2
)
{ {
return new NamedComfyNode<ModelNodeConnection>(name) public required ModelNodeConnection Model { get; init; }
{ public required double B1 { get; init; }
ClassType = "FreeU", public required double B2 { get; init; }
Inputs = new Dictionary<string, object?> public required double S1 { get; init; }
{ public required double S2 { get; init; }
["model"] = model.Data, }
["b1"] = b1,
["b2"] = b2, [SuppressMessage("ReSharper", "InconsistentNaming")]
["s1"] = s1, public record CLIPTextEncode : ComfyTypedNodeBase<ConditioningNodeConnection>
["s2"] = s2 {
} public required ClipNodeConnection Clip { get; init; }
}; public required string Text { get; init; }
} }
public static NamedComfyNode<ConditioningNodeConnection> ClipTextEncode( public static NamedComfyNode<ConditioningNodeConnection> ClipTextEncode(
@ -818,6 +837,9 @@ public class ComfyNodeBuilder
public VAENodeConnection? PrimaryVAE { get; set; } public VAENodeConnection? PrimaryVAE { get; set; }
public Size PrimarySize { get; set; } public Size PrimarySize { get; set; }
public ComfySampler? PrimarySampler { get; set; }
public ComfyScheduler? PrimaryScheduler { get; set; }
public List<NamedComfyNode> OutputNodes { get; } = new(); public List<NamedComfyNode> OutputNodes { get; } = new();
public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name); public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name);

18
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs

@ -1,5 +1,6 @@
using System.Reflection; using System.Reflection;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using Yoh.Text.Json.NamingPolicies; using Yoh.Text.Json.NamingPolicies;
@ -32,6 +33,23 @@ public abstract record ComfyTypedNodeBase
property.GetCustomAttribute<JsonPropertyNameAttribute>()?.Name property.GetCustomAttribute<JsonPropertyNameAttribute>()?.Name
?? namingPolicy.ConvertName(property.Name); ?? namingPolicy.ConvertName(property.Name);
// If theres a BoolStringMember attribute, convert to one of the strings
if (property.GetCustomAttribute<BoolStringMemberAttribute>() is { } converter)
{
if (value is bool boolValue)
{
inputs.Add(key, boolValue ? converter.TrueString : converter.FalseString);
}
else
{
throw new InvalidOperationException(
$"Property {property.Name} is not a bool, but has a BoolStringMember attribute"
);
}
continue;
}
// For connection types, use data property // For connection types, use data property
if (value is NodeConnectionBase connection) if (value is NodeConnectionBase connection)
{ {

Loading…
Cancel
Save