Browse Source

Implement IComfyStep apply for modules

pull/333/head
Ionite 1 year ago
parent
commit
9afe848cdb
No known key found for this signature in database
  1. 56
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 2
      StabilityMatrix.Avalonia/Models/IJsonLoadableState.cs
  3. 5
      StabilityMatrix.Avalonia/Models/ImageSource.cs
  4. 24
      StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs
  5. 17
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  6. 18
      StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs
  7. 85
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  8. 62
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  9. 52
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs
  10. 45
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs
  11. 39
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

56
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.Drawing; using System.Drawing;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
@ -58,7 +59,7 @@ public static class ComfyNodeBuilderExtensions
Action<ComfyNodeBuilder>? postModelLoad = null Action<ComfyNodeBuilder>? postModelLoad = null
) )
{ {
// Load base checkpoint /*// Load base checkpoint
var checkpointLoader = builder.Nodes.AddNamedNode( var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple( ComfyNodeBuilder.CheckpointLoaderSimple(
"CheckpointLoader", "CheckpointLoader",
@ -73,7 +74,7 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.PrimaryVAE = builder.Connections.BaseVAE; builder.Connections.PrimaryVAE = builder.Connections.BaseVAE;
// Run post model load action // Run post model load action
postModelLoad?.Invoke(builder); postModelLoad?.Invoke(builder);*/
// Load prompts // Load prompts
var prompt = promptCardViewModel.GetPrompt(); var prompt = promptCardViewModel.GetPrompt();
@ -115,6 +116,28 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.BaseConditioning = positiveClip.Output; builder.Connections.BaseConditioning = positiveClip.Output;
builder.Connections.BaseNegativeConditioning = negativeClip.Output; builder.Connections.BaseNegativeConditioning = negativeClip.Output;
// Apply sampler addons (FreeU / ControlNet) to model and conditioning
var samplerStepArgs = new ModuleApplyStepEventArgs
{
Builder = builder,
Temp =
{
Model = builder.Connections.BaseModel,
Conditioning = (positiveClip.Output, negativeClip.Output)
}
};
samplerCardViewModel.ApplyStep(samplerStepArgs);
var model = samplerStepArgs.Temp.Model;
var conditioning = samplerStepArgs.Temp.Conditioning;
// Primary latent encoding vae
var vae =
builder.Connections.PrimaryVAE
?? builder.Connections.BaseVAE
?? throw new ValidationException("No Primary or Base VAE");
var latent = builder.GetPrimaryAsLatent(vae);
// Add base sampler (without refiner) // Add base sampler (without refiner)
if ( if (
modelCardViewModel modelCardViewModel
@ -124,18 +147,17 @@ public static class ComfyNodeBuilderExtensions
var sampler = builder.Nodes.AddNamedNode( var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSampler( ComfyNodeBuilder.KSampler(
"Sampler", "Sampler",
builder.Connections.BaseModel, model,
builder.Connections.Seed, builder.Connections.Seed,
samplerCardViewModel.Steps, samplerCardViewModel.Steps,
samplerCardViewModel.CfgScale, samplerCardViewModel.CfgScale,
samplerCardViewModel.SelectedSampler samplerCardViewModel.SelectedSampler
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Scheduler not selected"),
positiveClip.Output, conditioning.Positive,
negativeClip.Output, conditioning.Negative,
builder.GetPrimaryAsLatent() latent,
?? throw new ValidationException("Latent source not set"),
samplerCardViewModel.DenoiseStrength samplerCardViewModel.DenoiseStrength
) )
); );
@ -150,7 +172,7 @@ public static class ComfyNodeBuilderExtensions
var sampler = builder.Nodes.AddNamedNode( var sampler = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.KSamplerAdvanced( ComfyNodeBuilder.KSamplerAdvanced(
"Sampler", "Sampler",
builder.Connections.BaseModel, model,
true, true,
builder.Connections.Seed, builder.Connections.Seed,
totalSteps, totalSteps,
@ -159,9 +181,9 @@ public static class ComfyNodeBuilderExtensions
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Sampler not selected"),
samplerCardViewModel.SelectedScheduler samplerCardViewModel.SelectedScheduler
?? throw new ValidationException("Sampler not selected"), ?? throw new ValidationException("Sampler not selected"),
positiveClip.Output, conditioning.Positive,
negativeClip.Output, conditioning.Negative,
builder.GetPrimaryAsLatent(), latent,
0, 0,
samplerCardViewModel.Steps, samplerCardViewModel.Steps,
true true
@ -180,7 +202,7 @@ public static class ComfyNodeBuilderExtensions
Action<ComfyNodeBuilder>? postModelLoad = null Action<ComfyNodeBuilder>? postModelLoad = null
) )
{ {
// Load refiner checkpoint /*// Load refiner checkpoint
var checkpointLoader = builder.Nodes.AddNamedNode( var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple( ComfyNodeBuilder.CheckpointLoaderSimple(
"Refiner_CheckpointLoader", "Refiner_CheckpointLoader",
@ -195,7 +217,7 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.PrimaryVAE = builder.Connections.RefinerVAE; builder.Connections.PrimaryVAE = builder.Connections.RefinerVAE;
// Run post model load action // Run post model load action
postModelLoad?.Invoke(builder); postModelLoad?.Invoke(builder);*/
// Load prompts // Load prompts
var prompt = promptCardViewModel.GetPrompt(); var prompt = promptCardViewModel.GetPrompt();
@ -272,7 +294,11 @@ public static class ComfyNodeBuilderExtensions
new ComfyNodeBuilder.PreviewImage new ComfyNodeBuilder.PreviewImage
{ {
Name = "SaveImage", Name = "SaveImage",
Images = builder.GetPrimaryAsImage() Images = builder.GetPrimaryAsImage(
builder.Connections.PrimaryVAE
?? builder.Connections.RefinerVAE
?? builder.Connections.BaseVAE
)
} }
); );

2
StabilityMatrix.Avalonia/Models/IJsonLoadableState.cs

@ -4,6 +4,8 @@ namespace StabilityMatrix.Avalonia.Models;
public interface IJsonLoadableState public interface IJsonLoadableState
{ {
void LoadStateFromJsonObject(JsonObject state, int version);
void LoadStateFromJsonObject(JsonObject state); void LoadStateFromJsonObject(JsonObject state);
JsonObject SaveStateToJsonObject(); JsonObject SaveStateToJsonObject();

5
StabilityMatrix.Avalonia/Models/ImageSource.cs

@ -128,6 +128,11 @@ public record ImageSource : IDisposable
return guid + extension; return guid + extension;
} }
public string GetHashGuidFileNameCached(string pathPrefix)
{
return Path.Combine(pathPrefix, GetHashGuidFileNameCached());
}
/// <summary> /// <summary>
/// Clears the cached bitmap /// Clears the cached bitmap
/// </summary> /// </summary>

24
StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs

@ -14,13 +14,7 @@ public class ModuleApplyStepEventArgs : EventArgs
public NodeDictionary Nodes => Builder.Nodes; public NodeDictionary Nodes => Builder.Nodes;
/// <summary> public ModuleApplyStepTemporaryArgs Temp { get; } = new();
/// Temporary conditioning apply step, used by samplers to apply control net.
/// </summary>
public (
ConditioningNodeConnection Positive,
ConditioningNodeConnection Negative
) Conditioning { get; set; }
/// <summary> /// <summary>
/// Index of the step in the pipeline. /// Index of the step in the pipeline.
@ -37,4 +31,20 @@ public class ModuleApplyStepEventArgs : EventArgs
/// </summary> /// </summary>
public IReadOnlyDictionary<Type, bool> IsEnabledOverrides { get; init; } = public IReadOnlyDictionary<Type, bool> IsEnabledOverrides { get; init; } =
new Dictionary<Type, bool>(); new Dictionary<Type, bool>();
public class ModuleApplyStepTemporaryArgs
{
/// <summary>
/// Temporary conditioning apply step, used by samplers to apply control net.
/// </summary>
public (
ConditioningNodeConnection Positive,
ConditioningNodeConnection Negative
) Conditioning { get; set; }
/// <summary>
/// Temporary model apply step, used by samplers to apply control net.
/// </summary>
public ModelNodeConnection? Model { get; set; }
}
} }

17
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -22,6 +22,7 @@ using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
@ -602,5 +603,21 @@ public abstract partial class InferenceGenerationViewModelBase
public ComfyNodeBuilder Builder { get; } = new(); public ComfyNodeBuilder Builder { get; } = new();
public GenerateOverrides Overrides { get; init; } = new(); public GenerateOverrides Overrides { get; init; } = new();
public long? SeedOverride { get; init; } public long? SeedOverride { get; init; }
public static implicit operator ModuleApplyStepEventArgs(BuildPromptEventArgs args)
{
var overrides = new Dictionary<Type, bool>();
if (args.Overrides.IsHiresFixEnabled.HasValue)
{
overrides[typeof(HiresFixModule)] = args.Overrides.IsHiresFixEnabled.Value;
}
return new ModuleApplyStepEventArgs
{
Builder = args.Builder,
IsEnabledOverrides = overrides
};
}
} }
} }

18
StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs

@ -1,4 +1,5 @@
using System; using System;
using System.ComponentModel;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using System.Text.Json; using System.Text.Json;
@ -13,11 +14,15 @@ using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
namespace StabilityMatrix.Avalonia.ViewModels.Base; namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(StackExpanderViewModel), StackExpanderViewModel.ModuleKey)] [JsonDerivedType(typeof(StackExpanderViewModel), StackExpanderViewModel.ModuleKey)]
[JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(SamplerCardViewModel), SamplerCardViewModel.ModuleKey)] [JsonDerivedType(typeof(SamplerCardViewModel), SamplerCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(ControlNetCardViewModel), ControlNetCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(FreeUModule))]
[JsonDerivedType(typeof(HiresFixModule))]
[JsonDerivedType(typeof(UpscalerModule))] [JsonDerivedType(typeof(UpscalerModule))]
[JsonDerivedType(typeof(ControlNetModule))]
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -28,10 +33,10 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
typeof(IRelayCommand) typeof(IRelayCommand)
}; };
private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors), }; private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) };
private static readonly JsonSerializerOptions SerializerOptions = private static readonly JsonSerializerOptions SerializerOptions =
new() { IgnoreReadOnlyProperties = true, }; new() { IgnoreReadOnlyProperties = true };
private static bool ShouldIgnoreProperty(PropertyInfo property) private static bool ShouldIgnoreProperty(PropertyInfo property)
{ {
@ -280,6 +285,11 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
return state; return state;
} }
public virtual void LoadStateFromJsonObject(JsonObject state, int version)
{
LoadStateFromJsonObject(state);
}
/// <summary> /// <summary>
/// Serialize a model to a JSON object. /// Serialize a model to a JSON object.
/// </summary> /// </summary>

85
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.Drawing; using System.Drawing;
using System.Linq; using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -209,33 +210,18 @@ public class InferenceTextToImageViewModel
builder.Connections.Seed = Convert.ToUInt64(SeedCardViewModel.Seed); builder.Connections.Seed = Convert.ToUInt64(SeedCardViewModel.Seed);
} }
// Load models
ModelCardViewModel.ApplyStep(args);
// Setup empty latent // Setup empty latent
builder.SetupLatentSource(BatchSizeCardViewModel, SamplerCardViewModel); builder.SetupLatentSource(BatchSizeCardViewModel, SamplerCardViewModel);
// Setup base stage // Setup base sampling stage
builder.SetupBaseSampler( builder.SetupBaseSampler(
SamplerCardViewModel, SamplerCardViewModel,
PromptCardViewModel, PromptCardViewModel,
ModelCardViewModel, ModelCardViewModel,
modelIndexService, modelIndexService
postModelLoad: x =>
{
if (IsFreeUEnabled)
{
builder.Connections.BaseModel = nodes
.AddNamedNode(
ComfyNodeBuilder.FreeU(
"FreeU",
x.Connections.BaseModel!,
FreeUCardViewModel.B1,
FreeUCardViewModel.B2,
FreeUCardViewModel.S1,
FreeUCardViewModel.S2
)
)
.Output;
}
}
); );
// Setup refiner stage if enabled // Setup refiner stage if enabled
@ -248,40 +234,22 @@ public class InferenceTextToImageViewModel
SamplerCardViewModel, SamplerCardViewModel,
PromptCardViewModel, PromptCardViewModel,
ModelCardViewModel, ModelCardViewModel,
modelIndexService, modelIndexService
postModelLoad: x =>
{
if (IsFreeUEnabled)
{
builder.Connections.RefinerModel = nodes
.AddNamedNode(
ComfyNodeBuilder.FreeU(
"Refiner_FreeU",
x.Connections.RefinerModel!,
FreeUCardViewModel.B1,
FreeUCardViewModel.B2,
FreeUCardViewModel.S1,
FreeUCardViewModel.S2
)
)
.Output;
}
}
); );
} }
// Override with custom VAE if enabled // Override with custom VAE if enabled
if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false }) /*if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false })
{ {
var customVaeLoader = nodes.AddNamedNode( var customVaeLoader = nodes.AddNamedNode(
ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath) ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath)
); );
builder.Connections.PrimaryVAE = customVaeLoader.Output; builder.Connections.PrimaryVAE = customVaeLoader.Output;
} }*/
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node // If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled) /*if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
{ {
// Get new latent size // Get new latent size
var hiresSize = builder.Connections.PrimarySize.WithScale( var hiresSize = builder.Connections.PrimarySize.WithScale(
@ -329,10 +297,10 @@ public class InferenceTextToImageViewModel
// Set as primary // Set as primary
builder.Connections.Primary = hiresSampler.Output; builder.Connections.Primary = hiresSampler.Output;
builder.Connections.PrimarySize = hiresSize; builder.Connections.PrimarySize = hiresSize;
} }*/
// If upscale is enabled, add another upscale group // If upscale is enabled, add another upscale group
if (IsUpscaleEnabled) /*if (IsUpscaleEnabled)
{ {
var upscaleSize = builder.Connections.PrimarySize.WithScale( var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale UpscalerCardViewModel.Scale
@ -349,11 +317,20 @@ public class InferenceTextToImageViewModel
builder.Connections.Primary = upscaleResult; builder.Connections.Primary = upscaleResult;
builder.Connections.PrimarySize = upscaleSize; builder.Connections.PrimarySize = upscaleSize;
} }*/
builder.SetupOutputImage(); builder.SetupOutputImage();
} }
/// <inheritdoc />
protected override IEnumerable<ImageSource> GetInputImages()
{
// TODO support hires in some generic way
return SamplerCardViewModel.ModulesCardViewModel.Cards
.OfType<ControlNetModule>()
.SelectMany(m => m.GetInputImages());
}
/// <inheritdoc /> /// <inheritdoc />
protected override async Task GenerateImageImpl( protected override async Task GenerateImageImpl(
GenerateOverrides overrides, GenerateOverrides overrides,
@ -435,4 +412,22 @@ public class InferenceTextToImageViewModel
return parameters; return parameters;
} }
// Migration for v2 deserialization
public override void LoadStateFromJsonObject(JsonObject state, int version)
{
if (version > 2)
{
LoadStateFromJsonObject(state);
}
ModulesCardViewModel.Clear();
// Add by default the original cards - FreeU, HiresFix, Upscaler
var hiresFix = ModulesCardViewModel.AddModule<HiresFixModule>();
var upscaler = ModulesCardViewModel.AddModule<UpscalerModule>();
hiresFix.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsHiresFixEnabled");
upscaler.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsUpscaleEnabled");
}
} }

62
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -1,20 +1,27 @@
using System; using System;
using System.ComponentModel.DataAnnotations;
using System.Linq; using System.Linq;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ModelCard))] [View(typeof(ModelCard))]
[ManagedService] [ManagedService]
[Transient] [Transient]
public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState public partial class ModelCardViewModel
: LoadableViewModelBase,
IParametersLoadableState,
IComfyStep
{ {
[ObservableProperty] [ObservableProperty]
private HybridModelFile? selectedModel; private HybridModelFile? selectedModel;
@ -42,6 +49,59 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
ClientManager = clientManager; ClientManager = clientManager;
} }
/// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Load base checkpoint
var baseLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.CheckpointLoaderSimple
{
Name = "CheckpointLoader",
CkptName =
SelectedModel?.RelativePath
?? throw new ValidationException("Model not selected")
}
);
e.Builder.Connections.BaseModel = baseLoader.Output1;
e.Builder.Connections.BaseClip = baseLoader.Output2;
e.Builder.Connections.BaseVAE = baseLoader.Output3;
// Load refiner
if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false })
{
var refinerLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.CheckpointLoaderSimple
{
Name = "Refiner_CheckpointLoader",
CkptName =
SelectedRefiner?.RelativePath
?? throw new ValidationException("Refiner Model enabled but not selected")
}
);
e.Builder.Connections.RefinerModel = refinerLoader.Output1;
e.Builder.Connections.RefinerClip = refinerLoader.Output2;
e.Builder.Connections.RefinerVAE = refinerLoader.Output3;
}
// Load custom VAE
if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false })
{
var vaeLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.VAELoader
{
Name = "VAELoader",
VaeName =
SelectedVae?.RelativePath
?? throw new ValidationException("VAE enabled but not selected")
}
);
e.Builder.Connections.PrimaryVAE = vaeLoader.Output;
}
}
/// <inheritdoc /> /// <inheritdoc />
public override JsonObject SaveStateToJsonObject() public override JsonObject SaveStateToJsonObject()
{ {

52
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs

@ -1,8 +1,13 @@
using StabilityMatrix.Avalonia.Controls; using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
@ -18,9 +23,52 @@ public class ControlNetModule : ModuleBase
AddCards(vmFactory.Get<ControlNetCardViewModel>()); AddCards(vmFactory.Get<ControlNetCardViewModel>());
} }
public IEnumerable<ImageSource> GetInputImages()
{
if (GetCard<ControlNetCardViewModel>().SelectImageCardViewModel.ImageSource is { } image)
{
yield return image;
}
}
/// <inheritdoc /> /// <inheritdoc />
protected override void OnApplyStep(ModuleApplyStepEventArgs e) protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{ {
throw new System.NotImplementedException(); var card = GetCard<ControlNetCardViewModel>();
var imageLoad = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.LoadImage
{
Name = e.Nodes.GetUniqueName("ControlNet_LoadImage"),
Image =
card.SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached(
"Inference"
) ?? throw new ValidationException()
}
);
var controlNetLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.ControlNetLoader
{
Name = e.Nodes.GetUniqueName("ControlNetLoader"),
ControlNetName = card.SelectedModel?.FileName ?? throw new ValidationException(),
}
);
var controlNetApply = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.ControlNetApplyAdvanced
{
Name = e.Nodes.GetUniqueName("ControlNet"),
Image = imageLoad.Output1,
ControlNet = controlNetLoader.Output,
Positive = e.Temp.Conditioning.Positive,
Negative = e.Temp.Conditioning.Negative,
Strength = card.Strength,
StartPercent = card.StartPercent,
EndPercent = card.EndPercent,
}
);
e.Temp.Conditioning = (controlNetApply.Output1, controlNetApply.Output2);
} }
} }

45
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs

@ -0,0 +1,45 @@
using System;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
[ManagedService]
[Transient]
public class FreeUModule : ModuleBase
{
/// <inheritdoc />
public FreeUModule(ServiceManager<ViewModelBase> vmFactory)
: base(vmFactory)
{
Title = "FreeU";
AddCards(vmFactory.Get<FreeUCardViewModel>());
}
/// <summary>
/// Applies FreeU to the Model property
/// </summary>
protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{
var card = GetCard<FreeUCardViewModel>();
e.Temp.Model = e.Nodes
.AddNamedNode(
ComfyNodeBuilder.FreeU(
e.Nodes.GetUniqueName("FreeU"),
e.Temp.Model
?? throw new ArgumentException(
"Temp.Model not set on ModuleApplyStepEventArgs"
),
card.B1,
card.B2,
card.S1,
card.S2
)
)
.Output;
}
}

39
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -1,8 +1,10 @@
using System.Linq; using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
@ -16,7 +18,10 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SamplerCard))] [View(typeof(SamplerCard))]
[ManagedService] [ManagedService]
[Transient] [Transient]
public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState public partial class SamplerCardViewModel
: LoadableViewModelBase,
IParametersLoadableState,
IComfyStep
{ {
public const string ModuleKey = "Sampler"; public const string ModuleKey = "Sampler";
@ -54,15 +59,18 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
private bool isSamplerSelectionEnabled; private bool isSamplerSelectionEnabled;
[ObservableProperty] [ObservableProperty]
[Required]
private ComfySampler? selectedSampler = ComfySampler.EulerAncestral; private ComfySampler? selectedSampler = ComfySampler.EulerAncestral;
[ObservableProperty] [ObservableProperty]
private bool isSchedulerSelectionEnabled; private bool isSchedulerSelectionEnabled;
[ObservableProperty] [ObservableProperty]
[Required]
private ComfyScheduler? selectedScheduler = ComfyScheduler.Normal; private ComfyScheduler? selectedScheduler = ComfyScheduler.Normal;
public StackEditableCardViewModel StackEditableCardViewModel { get; } [JsonPropertyName("Modules")]
public StackEditableCardViewModel ModulesCardViewModel { get; }
[JsonIgnore] [JsonIgnore]
public IInferenceClientManager ClientManager { get; } public IInferenceClientManager ClientManager { get; }
@ -73,12 +81,33 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
) )
{ {
ClientManager = clientManager; ClientManager = clientManager;
StackEditableCardViewModel = vmFactory.Get<StackEditableCardViewModel>(modulesCard => ModulesCardViewModel = vmFactory.Get<StackEditableCardViewModel>(modulesCard =>
{ {
modulesCard.Title = "Addons"; modulesCard.Title = "Addons";
modulesCard.AvailableModules = new[] { typeof(ControlNetModule) }; modulesCard.AvailableModules = new[] { typeof(FreeUModule), typeof(ControlNetModule) };
modulesCard.InitializeDefaults(); modulesCard.InitializeDefaults();
}); });
ModulesCardViewModel.CardAdded += (
(sender, item) =>
{
if (item is ControlNetModule module)
{
// Inherit our edit state
// module.IsEditEnabled = IsEditEnabled;
}
}
);
}
/// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Apply steps from our modules
foreach (var module in ModulesCardViewModel.Cards.Cast<ModuleBase>())
{
module.ApplyStep(e);
}
} }
/// <inheritdoc /> /// <inheritdoc />

Loading…
Cancel
Save