From 9afe848cdb69ed5ad9b271a42ec4b8a483f0ea58 Mon Sep 17 00:00:00 2001 From: Ionite Date: Sat, 28 Oct 2023 23:31:03 -0400 Subject: [PATCH] Implement IComfyStep apply for modules --- .../Extensions/ComfyNodeBuilderExtensions.cs | 56 ++++++++---- .../Models/IJsonLoadableState.cs | 2 + .../Models/ImageSource.cs | 5 ++ .../Inference/ModuleApplyStepEventArgs.cs | 24 ++++-- .../Base/InferenceGenerationViewModelBase.cs | 17 ++++ .../ViewModels/Base/LoadableViewModelBase.cs | 18 +++- .../InferenceTextToImageViewModel.cs | 85 +++++++++---------- .../Inference/ModelCardViewModel.cs | 62 +++++++++++++- .../Inference/Modules/ControlNetModule.cs | 52 +++++++++++- .../Inference/Modules/FreeUModule.cs | 45 ++++++++++ .../Inference/SamplerCardViewModel.cs | 39 +++++++-- 11 files changed, 326 insertions(+), 79 deletions(-) create mode 100644 StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs index 27cf0a97..e5b2d272 100644 --- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs +++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.ComponentModel.DataAnnotations; using System.Drawing; +using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; @@ -58,7 +59,7 @@ public static class ComfyNodeBuilderExtensions Action? postModelLoad = null ) { - // Load base checkpoint + /*// Load base checkpoint var checkpointLoader = builder.Nodes.AddNamedNode( ComfyNodeBuilder.CheckpointLoaderSimple( "CheckpointLoader", @@ -73,7 +74,7 @@ public static class ComfyNodeBuilderExtensions builder.Connections.PrimaryVAE = builder.Connections.BaseVAE; // Run post model load action - postModelLoad?.Invoke(builder); + postModelLoad?.Invoke(builder);*/ // Load prompts var prompt = promptCardViewModel.GetPrompt(); @@ -115,6 +116,28 @@ public static class ComfyNodeBuilderExtensions builder.Connections.BaseConditioning = positiveClip.Output; builder.Connections.BaseNegativeConditioning = negativeClip.Output; + // Apply sampler addons (FreeU / ControlNet) to model and conditioning + var samplerStepArgs = new ModuleApplyStepEventArgs + { + Builder = builder, + Temp = + { + Model = builder.Connections.BaseModel, + Conditioning = (positiveClip.Output, negativeClip.Output) + } + }; + + samplerCardViewModel.ApplyStep(samplerStepArgs); + var model = samplerStepArgs.Temp.Model; + var conditioning = samplerStepArgs.Temp.Conditioning; + + // Primary latent encoding vae + var vae = + builder.Connections.PrimaryVAE + ?? builder.Connections.BaseVAE + ?? throw new ValidationException("No Primary or Base VAE"); + var latent = builder.GetPrimaryAsLatent(vae); + // Add base sampler (without refiner) if ( modelCardViewModel @@ -124,18 +147,17 @@ public static class ComfyNodeBuilderExtensions var sampler = builder.Nodes.AddNamedNode( ComfyNodeBuilder.KSampler( "Sampler", - builder.Connections.BaseModel, + model, builder.Connections.Seed, samplerCardViewModel.Steps, samplerCardViewModel.CfgScale, samplerCardViewModel.SelectedSampler ?? throw new ValidationException("Sampler not selected"), samplerCardViewModel.SelectedScheduler - ?? throw new ValidationException("Sampler not selected"), - positiveClip.Output, - negativeClip.Output, - builder.GetPrimaryAsLatent() - ?? throw new ValidationException("Latent source not set"), + ?? throw new ValidationException("Scheduler not selected"), + conditioning.Positive, + conditioning.Negative, + latent, samplerCardViewModel.DenoiseStrength ) ); @@ -150,7 +172,7 @@ public static class ComfyNodeBuilderExtensions var sampler = builder.Nodes.AddNamedNode( ComfyNodeBuilder.KSamplerAdvanced( "Sampler", - builder.Connections.BaseModel, + model, true, builder.Connections.Seed, totalSteps, @@ -159,9 +181,9 @@ public static class ComfyNodeBuilderExtensions ?? throw new ValidationException("Sampler not selected"), samplerCardViewModel.SelectedScheduler ?? throw new ValidationException("Sampler not selected"), - positiveClip.Output, - negativeClip.Output, - builder.GetPrimaryAsLatent(), + conditioning.Positive, + conditioning.Negative, + latent, 0, samplerCardViewModel.Steps, true @@ -180,7 +202,7 @@ public static class ComfyNodeBuilderExtensions Action? postModelLoad = null ) { - // Load refiner checkpoint + /*// Load refiner checkpoint var checkpointLoader = builder.Nodes.AddNamedNode( ComfyNodeBuilder.CheckpointLoaderSimple( "Refiner_CheckpointLoader", @@ -195,7 +217,7 @@ public static class ComfyNodeBuilderExtensions builder.Connections.PrimaryVAE = builder.Connections.RefinerVAE; // Run post model load action - postModelLoad?.Invoke(builder); + postModelLoad?.Invoke(builder);*/ // Load prompts var prompt = promptCardViewModel.GetPrompt(); @@ -272,7 +294,11 @@ public static class ComfyNodeBuilderExtensions new ComfyNodeBuilder.PreviewImage { Name = "SaveImage", - Images = builder.GetPrimaryAsImage() + Images = builder.GetPrimaryAsImage( + builder.Connections.PrimaryVAE + ?? builder.Connections.RefinerVAE + ?? builder.Connections.BaseVAE + ) } ); diff --git a/StabilityMatrix.Avalonia/Models/IJsonLoadableState.cs b/StabilityMatrix.Avalonia/Models/IJsonLoadableState.cs index 95d4860d..7b48bd97 100644 --- a/StabilityMatrix.Avalonia/Models/IJsonLoadableState.cs +++ b/StabilityMatrix.Avalonia/Models/IJsonLoadableState.cs @@ -4,6 +4,8 @@ namespace StabilityMatrix.Avalonia.Models; public interface IJsonLoadableState { + void LoadStateFromJsonObject(JsonObject state, int version); + void LoadStateFromJsonObject(JsonObject state); JsonObject SaveStateToJsonObject(); diff --git a/StabilityMatrix.Avalonia/Models/ImageSource.cs b/StabilityMatrix.Avalonia/Models/ImageSource.cs index 9ba3d828..d5a33b87 100644 --- a/StabilityMatrix.Avalonia/Models/ImageSource.cs +++ b/StabilityMatrix.Avalonia/Models/ImageSource.cs @@ -128,6 +128,11 @@ public record ImageSource : IDisposable return guid + extension; } + public string GetHashGuidFileNameCached(string pathPrefix) + { + return Path.Combine(pathPrefix, GetHashGuidFileNameCached()); + } + /// /// Clears the cached bitmap /// diff --git a/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs b/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs index 2949787a..662b9915 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs @@ -14,13 +14,7 @@ public class ModuleApplyStepEventArgs : EventArgs public NodeDictionary Nodes => Builder.Nodes; - /// - /// Temporary conditioning apply step, used by samplers to apply control net. - /// - public ( - ConditioningNodeConnection Positive, - ConditioningNodeConnection Negative - ) Conditioning { get; set; } + public ModuleApplyStepTemporaryArgs Temp { get; } = new(); /// /// Index of the step in the pipeline. @@ -37,4 +31,20 @@ public class ModuleApplyStepEventArgs : EventArgs /// public IReadOnlyDictionary IsEnabledOverrides { get; init; } = new Dictionary(); + + public class ModuleApplyStepTemporaryArgs + { + /// + /// Temporary conditioning apply step, used by samplers to apply control net. + /// + public ( + ConditioningNodeConnection Positive, + ConditioningNodeConnection Negative + ) Conditioning { get; set; } + + /// + /// Temporary model apply step, used by samplers to apply control net. + /// + public ModelNodeConnection? Model { get; set; } + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index fae045e4..d27b1500 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -22,6 +22,7 @@ using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Inference; +using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; @@ -602,5 +603,21 @@ public abstract partial class InferenceGenerationViewModelBase public ComfyNodeBuilder Builder { get; } = new(); public GenerateOverrides Overrides { get; init; } = new(); public long? SeedOverride { get; init; } + + public static implicit operator ModuleApplyStepEventArgs(BuildPromptEventArgs args) + { + var overrides = new Dictionary(); + + if (args.Overrides.IsHiresFixEnabled.HasValue) + { + overrides[typeof(HiresFixModule)] = args.Overrides.IsHiresFixEnabled.Value; + } + + return new ModuleApplyStepEventArgs + { + Builder = args.Builder, + IsEnabledOverrides = overrides + }; + } } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs index 1ca569c2..aae0199c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs @@ -1,4 +1,5 @@ using System; +using System.ComponentModel; using System.Linq; using System.Reflection; using System.Text.Json; @@ -13,11 +14,15 @@ using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; namespace StabilityMatrix.Avalonia.ViewModels.Base; -[JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)] [JsonDerivedType(typeof(StackExpanderViewModel), StackExpanderViewModel.ModuleKey)] -[JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)] [JsonDerivedType(typeof(SamplerCardViewModel), SamplerCardViewModel.ModuleKey)] +[JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)] +[JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)] +[JsonDerivedType(typeof(ControlNetCardViewModel), ControlNetCardViewModel.ModuleKey)] +[JsonDerivedType(typeof(FreeUModule))] +[JsonDerivedType(typeof(HiresFixModule))] [JsonDerivedType(typeof(UpscalerModule))] +[JsonDerivedType(typeof(ControlNetModule))] public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -28,10 +33,10 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState typeof(IRelayCommand) }; - private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors), }; + private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) }; private static readonly JsonSerializerOptions SerializerOptions = - new() { IgnoreReadOnlyProperties = true, }; + new() { IgnoreReadOnlyProperties = true }; private static bool ShouldIgnoreProperty(PropertyInfo property) { @@ -280,6 +285,11 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState return state; } + public virtual void LoadStateFromJsonObject(JsonObject state, int version) + { + LoadStateFromJsonObject(state); + } + /// /// Serialize a model to a JSON object. /// diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 4ceee51e..0fa7042c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.ComponentModel.DataAnnotations; using System.Drawing; using System.Linq; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; @@ -209,33 +210,18 @@ public class InferenceTextToImageViewModel builder.Connections.Seed = Convert.ToUInt64(SeedCardViewModel.Seed); } + // Load models + ModelCardViewModel.ApplyStep(args); + // Setup empty latent builder.SetupLatentSource(BatchSizeCardViewModel, SamplerCardViewModel); - // Setup base stage + // Setup base sampling stage builder.SetupBaseSampler( SamplerCardViewModel, PromptCardViewModel, ModelCardViewModel, - modelIndexService, - postModelLoad: x => - { - if (IsFreeUEnabled) - { - builder.Connections.BaseModel = nodes - .AddNamedNode( - ComfyNodeBuilder.FreeU( - "FreeU", - x.Connections.BaseModel!, - FreeUCardViewModel.B1, - FreeUCardViewModel.B2, - FreeUCardViewModel.S1, - FreeUCardViewModel.S2 - ) - ) - .Output; - } - } + modelIndexService ); // Setup refiner stage if enabled @@ -248,40 +234,22 @@ public class InferenceTextToImageViewModel SamplerCardViewModel, PromptCardViewModel, ModelCardViewModel, - modelIndexService, - postModelLoad: x => - { - if (IsFreeUEnabled) - { - builder.Connections.RefinerModel = nodes - .AddNamedNode( - ComfyNodeBuilder.FreeU( - "Refiner_FreeU", - x.Connections.RefinerModel!, - FreeUCardViewModel.B1, - FreeUCardViewModel.B2, - FreeUCardViewModel.S1, - FreeUCardViewModel.S2 - ) - ) - .Output; - } - } + modelIndexService ); } // Override with custom VAE if enabled - if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false }) + /*if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false }) { var customVaeLoader = nodes.AddNamedNode( ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath) ); builder.Connections.PrimaryVAE = customVaeLoader.Output; - } + }*/ // If hi-res fix is enabled, add the LatentUpscale node and another KSampler node - if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled) + /*if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled) { // Get new latent size var hiresSize = builder.Connections.PrimarySize.WithScale( @@ -329,10 +297,10 @@ public class InferenceTextToImageViewModel // Set as primary builder.Connections.Primary = hiresSampler.Output; builder.Connections.PrimarySize = hiresSize; - } + }*/ // If upscale is enabled, add another upscale group - if (IsUpscaleEnabled) + /*if (IsUpscaleEnabled) { var upscaleSize = builder.Connections.PrimarySize.WithScale( UpscalerCardViewModel.Scale @@ -349,11 +317,20 @@ public class InferenceTextToImageViewModel builder.Connections.Primary = upscaleResult; builder.Connections.PrimarySize = upscaleSize; - } + }*/ builder.SetupOutputImage(); } + /// + protected override IEnumerable GetInputImages() + { + // TODO support hires in some generic way + return SamplerCardViewModel.ModulesCardViewModel.Cards + .OfType() + .SelectMany(m => m.GetInputImages()); + } + /// protected override async Task GenerateImageImpl( GenerateOverrides overrides, @@ -435,4 +412,22 @@ public class InferenceTextToImageViewModel return parameters; } + + // Migration for v2 deserialization + public override void LoadStateFromJsonObject(JsonObject state, int version) + { + if (version > 2) + { + LoadStateFromJsonObject(state); + } + + ModulesCardViewModel.Clear(); + + // Add by default the original cards - FreeU, HiresFix, Upscaler + var hiresFix = ModulesCardViewModel.AddModule(); + var upscaler = ModulesCardViewModel.AddModule(); + + hiresFix.IsEnabled = state.GetPropertyValueOrDefault("IsHiresFixEnabled"); + upscaler.IsEnabled = state.GetPropertyValueOrDefault("IsUpscaleEnabled"); + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs index d69f5b80..4ccc3868 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs @@ -1,20 +1,27 @@ using System; +using System.ComponentModel.DataAnnotations; using System.Linq; using System.Text.Json.Nodes; using CommunityToolkit.Mvvm.ComponentModel; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api.Comfy.Nodes; +using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(ModelCard))] [ManagedService] [Transient] -public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState +public partial class ModelCardViewModel + : LoadableViewModelBase, + IParametersLoadableState, + IComfyStep { [ObservableProperty] private HybridModelFile? selectedModel; @@ -42,6 +49,59 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad ClientManager = clientManager; } + /// + public void ApplyStep(ModuleApplyStepEventArgs e) + { + // Load base checkpoint + var baseLoader = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.CheckpointLoaderSimple + { + Name = "CheckpointLoader", + CkptName = + SelectedModel?.RelativePath + ?? throw new ValidationException("Model not selected") + } + ); + + e.Builder.Connections.BaseModel = baseLoader.Output1; + e.Builder.Connections.BaseClip = baseLoader.Output2; + e.Builder.Connections.BaseVAE = baseLoader.Output3; + + // Load refiner + if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false }) + { + var refinerLoader = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.CheckpointLoaderSimple + { + Name = "Refiner_CheckpointLoader", + CkptName = + SelectedRefiner?.RelativePath + ?? throw new ValidationException("Refiner Model enabled but not selected") + } + ); + + e.Builder.Connections.RefinerModel = refinerLoader.Output1; + e.Builder.Connections.RefinerClip = refinerLoader.Output2; + e.Builder.Connections.RefinerVAE = refinerLoader.Output3; + } + + // Load custom VAE + if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false }) + { + var vaeLoader = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.VAELoader + { + Name = "VAELoader", + VaeName = + SelectedVae?.RelativePath + ?? throw new ValidationException("VAE enabled but not selected") + } + ); + + e.Builder.Connections.PrimaryVAE = vaeLoader.Output; + } + } + /// public override JsonObject SaveStateToJsonObject() { diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs index e6278412..20c5973f 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs @@ -1,8 +1,13 @@ -using StabilityMatrix.Avalonia.Controls; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models.Api.Comfy.Nodes; namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; @@ -18,9 +23,52 @@ public class ControlNetModule : ModuleBase AddCards(vmFactory.Get()); } + public IEnumerable GetInputImages() + { + if (GetCard().SelectImageCardViewModel.ImageSource is { } image) + { + yield return image; + } + } + /// protected override void OnApplyStep(ModuleApplyStepEventArgs e) { - throw new System.NotImplementedException(); + var card = GetCard(); + + var imageLoad = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.LoadImage + { + Name = e.Nodes.GetUniqueName("ControlNet_LoadImage"), + Image = + card.SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached( + "Inference" + ) ?? throw new ValidationException() + } + ); + + var controlNetLoader = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ControlNetLoader + { + Name = e.Nodes.GetUniqueName("ControlNetLoader"), + ControlNetName = card.SelectedModel?.FileName ?? throw new ValidationException(), + } + ); + + var controlNetApply = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ControlNetApplyAdvanced + { + Name = e.Nodes.GetUniqueName("ControlNet"), + Image = imageLoad.Output1, + ControlNet = controlNetLoader.Output, + Positive = e.Temp.Conditioning.Positive, + Negative = e.Temp.Conditioning.Negative, + Strength = card.Strength, + StartPercent = card.StartPercent, + EndPercent = card.EndPercent, + } + ); + + e.Temp.Conditioning = (controlNetApply.Output1, controlNetApply.Output2); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs new file mode 100644 index 00000000..622e0e7d --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/FreeUModule.cs @@ -0,0 +1,45 @@ +using System; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models.Api.Comfy.Nodes; + +namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; + +[ManagedService] +[Transient] +public class FreeUModule : ModuleBase +{ + /// + public FreeUModule(ServiceManager vmFactory) + : base(vmFactory) + { + Title = "FreeU"; + AddCards(vmFactory.Get()); + } + + /// + /// Applies FreeU to the Model property + /// + protected override void OnApplyStep(ModuleApplyStepEventArgs e) + { + var card = GetCard(); + + e.Temp.Model = e.Nodes + .AddNamedNode( + ComfyNodeBuilder.FreeU( + e.Nodes.GetUniqueName("FreeU"), + e.Temp.Model + ?? throw new ArgumentException( + "Temp.Model not set on ModuleApplyStepEventArgs" + ), + card.B1, + card.B2, + card.S1, + card.S2 + ) + ) + .Output; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index 5c7b4ac4..30ff44e1 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -1,8 +1,10 @@ -using System.Linq; +using System.ComponentModel.DataAnnotations; +using System.Linq; using System.Text.Json.Serialization; using CommunityToolkit.Mvvm.ComponentModel; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; @@ -16,7 +18,10 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(SamplerCard))] [ManagedService] [Transient] -public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState +public partial class SamplerCardViewModel + : LoadableViewModelBase, + IParametersLoadableState, + IComfyStep { public const string ModuleKey = "Sampler"; @@ -54,15 +59,18 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo private bool isSamplerSelectionEnabled; [ObservableProperty] + [Required] private ComfySampler? selectedSampler = ComfySampler.EulerAncestral; [ObservableProperty] private bool isSchedulerSelectionEnabled; [ObservableProperty] + [Required] private ComfyScheduler? selectedScheduler = ComfyScheduler.Normal; - public StackEditableCardViewModel StackEditableCardViewModel { get; } + [JsonPropertyName("Modules")] + public StackEditableCardViewModel ModulesCardViewModel { get; } [JsonIgnore] public IInferenceClientManager ClientManager { get; } @@ -73,12 +81,33 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo ) { ClientManager = clientManager; - StackEditableCardViewModel = vmFactory.Get(modulesCard => + ModulesCardViewModel = vmFactory.Get(modulesCard => { modulesCard.Title = "Addons"; - modulesCard.AvailableModules = new[] { typeof(ControlNetModule) }; + modulesCard.AvailableModules = new[] { typeof(FreeUModule), typeof(ControlNetModule) }; modulesCard.InitializeDefaults(); }); + + ModulesCardViewModel.CardAdded += ( + (sender, item) => + { + if (item is ControlNetModule module) + { + // Inherit our edit state + // module.IsEditEnabled = IsEditEnabled; + } + } + ); + } + + /// + public void ApplyStep(ModuleApplyStepEventArgs e) + { + // Apply steps from our modules + foreach (var module in ModulesCardViewModel.Cards.Cast()) + { + module.ApplyStep(e); + } } ///