From e7be967cdd91bac703b07b0fe5eb5bf5fa8c0e8c Mon Sep 17 00:00:00 2001 From: Ionite Date: Mon, 11 Mar 2024 16:56:13 -0400 Subject: [PATCH] Reference controlnet refactors --- .../Inference/ModuleApplyStepEventArgs.cs | 36 +++++----- .../Inference/Modules/ControlNetModule.cs | 67 +++++++++++++++++-- .../Inference/Modules/HiresFixModule.cs | 57 +++++++++------- .../Inference/SamplerCardViewModel.cs | 17 +++-- StabilityMatrix.Core/Helper/RemoteModels.cs | 8 ++- .../Api/Comfy/Nodes/ComfyNodeBuilder.cs | 37 ++++++++++ .../Models/HybridModelFile.cs | 6 ++ .../Inference/ModuleApplyStepTemporaryArgs.cs | 36 ++++++++++ 8 files changed, 205 insertions(+), 59 deletions(-) create mode 100644 StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs diff --git a/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs b/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs index 375ca17f..8da5fe6d 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs @@ -4,7 +4,7 @@ using System.IO; using System.IO.Hashing; using System.Text; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; -using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; +using StabilityMatrix.Core.Models.Inference; namespace StabilityMatrix.Avalonia.Models.Inference; @@ -17,7 +17,7 @@ public class ModuleApplyStepEventArgs : EventArgs public NodeDictionary Nodes => Builder.Nodes; - public ModuleApplyStepTemporaryArgs Temp { get; } = new(); + public ModuleApplyStepTemporaryArgs Temp { get; set; } = new(); /// /// Generation overrides (like hires fix generate, current seed generate, etc.) @@ -26,6 +26,20 @@ public class ModuleApplyStepEventArgs : EventArgs public List<(string SourcePath, string DestinationRelativePath)> FilesToTransfer { get; init; } = []; + /// + /// Creates a new with the given . + /// + /// + public ModuleApplyStepTemporaryArgs CreateTempFromBuilder() + { + return new ModuleApplyStepTemporaryArgs + { + Primary = Builder.Connections.Primary, + PrimaryVAE = Builder.Connections.PrimaryVAE, + Models = Builder.Connections.Models + }; + } + public void AddFileTransfer(string sourcePath, string destinationRelativePath) { FilesToTransfer.Add((sourcePath, destinationRelativePath)); @@ -54,22 +68,4 @@ public class ModuleApplyStepEventArgs : EventArgs return destPath; } - - public class ModuleApplyStepTemporaryArgs - { - /// - /// Temporary conditioning apply step, used by samplers to apply control net. - /// - public ConditioningConnections? Conditioning { get; set; } - - /// - /// Temporary refiner conditioning apply step, used by samplers to apply control net. - /// - public ConditioningConnections? RefinerConditioning { get; set; } - - /// - /// Temporary model apply step, used by samplers to apply control net. - /// - public ModelNodeConnection? Model { get; set; } - } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs index 23260daf..542f7a1f 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs @@ -6,6 +6,9 @@ using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; @@ -49,6 +52,56 @@ public class ControlNetModule : ModuleBase } ); + // If ReferenceOnly is selected, use special node + if (card.SelectedModel == RemoteModels.ControlNetReferenceOnlyModel) + { + // We need to rescale image to be the current primary size if it's not already + var primarySize = e.Builder.Connections.PrimarySize; + if (card.SelectImageCardViewModel.CurrentBitmapSize != primarySize) + { + var scaled = e.Builder.Group_Upscale( + e.Nodes.GetUniqueName("ControlNet_Rescale"), + image, + e.Temp.GetDefaultVAE(), + ComfyUpscaler.NearestExact, + primarySize.Width, + primarySize.Width + ); + e.Temp.Primary = scaled; + } + else + { + e.Temp.Primary = image; + } + + // Set image as new latent source, add reference only node + var model = e.Temp.GetRefinerOrBaseModel(); + var controlNetReferenceOnly = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ReferenceOnlySimple + { + Name = e.Nodes.GetUniqueName("ControlNet_ReferenceOnly"), + Reference = e.Builder.GetPrimaryAsLatent( + e.Temp.Primary, + e.Builder.Connections.GetDefaultVAE() + ), + Model = model + } + ); + + // Set output as new primary and model source + if (model == e.Temp.Refiner.Model) + { + e.Temp.Refiner.Model = controlNetReferenceOnly.Output1; + } + else + { + e.Temp.Base.Model = controlNetReferenceOnly.Output1; + } + e.Temp.Primary = controlNetReferenceOnly.Output2; + + return; + } + var controlNetLoader = e.Nodes.AddTypedNode( new ComfyNodeBuilder.ControlNetLoader { @@ -64,18 +117,18 @@ public class ControlNetModule : ModuleBase Name = e.Nodes.GetUniqueName("ControlNetApply"), Image = imageLoad.Output1, ControlNet = controlNetLoader.Output, - Positive = e.Temp.Conditioning?.Positive ?? throw new ArgumentException("No Conditioning"), - Negative = e.Temp.Conditioning?.Negative ?? throw new ArgumentException("No Conditioning"), + Positive = e.Temp.Base.Conditioning!.Unwrap().Positive, + Negative = e.Temp.Base.Conditioning.Negative, Strength = card.Strength, StartPercent = card.StartPercent, EndPercent = card.EndPercent, } ); - e.Temp.Conditioning = (controlNetApply.Output1, controlNetApply.Output2); + e.Temp.Base.Conditioning = (controlNetApply.Output1, controlNetApply.Output2); // Refiner if available - if (e.Temp.RefinerConditioning is not null) + if (e.Temp.Refiner.Conditioning is not null) { var controlNetRefinerApply = e.Nodes.AddTypedNode( new ComfyNodeBuilder.ControlNetApplyAdvanced @@ -83,15 +136,15 @@ public class ControlNetModule : ModuleBase Name = e.Nodes.GetUniqueName("Refiner_ControlNetApply"), Image = imageLoad.Output1, ControlNet = controlNetLoader.Output, - Positive = e.Temp.RefinerConditioning.Positive, - Negative = e.Temp.RefinerConditioning.Negative, + Positive = e.Temp.Refiner.Conditioning!.Unwrap().Positive, + Negative = e.Temp.Refiner.Conditioning.Negative, Strength = card.Strength, StartPercent = card.StartPercent, EndPercent = card.EndPercent, } ); - e.Temp.RefinerConditioning = (controlNetRefinerApply.Output1, controlNetRefinerApply.Output2); + e.Temp.Refiner.Conditioning = (controlNetRefinerApply.Output1, controlNetRefinerApply.Output2); } } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs index 588cd263..94a0d1bc 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs @@ -78,30 +78,39 @@ public partial class HiresFixModule : ModuleBase ); } - var hiresSampler = builder - .Nodes - .AddTypedNode( - new ComfyNodeBuilder.KSampler - { - Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"), - Model = builder.Connections.GetRefinerOrBaseModel(), - Seed = builder.Connections.Seed, - Steps = samplerCard.Steps, - Cfg = samplerCard.CfgScale, - SamplerName = - samplerCard.SelectedSampler?.Name - ?? e.Builder.Connections.PrimarySampler?.Name - ?? throw new ArgumentException("No PrimarySampler"), - Scheduler = - samplerCard.SelectedScheduler?.Name - ?? e.Builder.Connections.PrimaryScheduler?.Name - ?? throw new ArgumentException("No PrimaryScheduler"), - Positive = builder.Connections.GetRefinerOrBaseConditioning().Positive, - Negative = builder.Connections.GetRefinerOrBaseConditioning().Negative, - LatentImage = builder.GetPrimaryAsLatent(), - Denoise = samplerCard.DenoiseStrength - } - ); + // If we need to inherit primary sampler addons, use their temp args + if (samplerCard.InheritPrimarySamplerAddons) + { + e.Temp = e.Builder.Connections.BaseSamplerTemporaryArgs ?? e.CreateTempFromBuilder(); + } + else + { + // otherwise just use new ones + e.Temp = e.CreateTempFromBuilder(); + } + + var hiresSampler = builder.Nodes.AddTypedNode( + new ComfyNodeBuilder.KSampler + { + Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"), + Model = builder.Connections.GetRefinerOrBaseModel(), + Seed = builder.Connections.Seed, + Steps = samplerCard.Steps, + Cfg = samplerCard.CfgScale, + SamplerName = + samplerCard.SelectedSampler?.Name + ?? e.Builder.Connections.PrimarySampler?.Name + ?? throw new ArgumentException("No PrimarySampler"), + Scheduler = + samplerCard.SelectedScheduler?.Name + ?? e.Builder.Connections.PrimaryScheduler?.Name + ?? throw new ArgumentException("No PrimaryScheduler"), + Positive = e.Temp.GetRefinerOrBaseConditioning().Positive, + Negative = e.Temp.GetRefinerOrBaseConditioning().Negative, + LatentImage = builder.GetPrimaryAsLatent(), + Denoise = samplerCard.DenoiseStrength + } + ); // Set as primary builder.Connections.Primary = hiresSampler.Output; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index c2409259..eb1b443a 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -130,8 +130,7 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo } // Provide temp values - e.Temp.Conditioning = e.Builder.Connections.Base.Conditioning; - e.Temp.RefinerConditioning = e.Builder.Connections.Refiner.Conditioning; + e.Temp = e.CreateTempFromBuilder(); // Apply steps from our addons ApplyAddonSteps(e); @@ -142,6 +141,9 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo if (!e.Nodes.ContainsKey("Sampler")) { ApplyStepsInitialSampler(e); + + // Save temp + e.Builder.Connections.BaseSamplerTemporaryArgs = e.Temp; } else { @@ -152,7 +154,10 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo private void ApplyStepsInitialSampler(ModuleApplyStepEventArgs e) { // Get primary as latent using vae - var primaryLatent = e.Builder.GetPrimaryAsLatent(); + var primaryLatent = e.Builder.GetPrimaryAsLatent( + e.Temp.Primary!.Unwrap(), + e.Builder.Connections.GetDefaultVAE() + ); // Set primary sampler and scheduler var primarySampler = SelectedSampler ?? throw new ValidationException("Sampler not selected"); @@ -162,8 +167,8 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo e.Builder.Connections.PrimaryScheduler = primaryScheduler; // Use Temp Conditioning that may be modified by addons - var conditioning = e.Temp.Conditioning.Unwrap(); - var refinerConditioning = e.Temp.RefinerConditioning; + var conditioning = e.Temp.Base.Conditioning.Unwrap(); + var refinerConditioning = e.Temp.Refiner.Conditioning; // Use custom sampler if SDTurbo scheduler is selected if (e.Builder.Connections.PrimaryScheduler == ComfyScheduler.SDTurbo) @@ -216,8 +221,6 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo // Use KSampler if no refiner, otherwise need KSamplerAdvanced if (e.Builder.Connections.Refiner.Model is null) { - var baseConditioning = e.Builder.Connections.Base.Conditioning.Unwrap(); - // No refiner var sampler = e.Nodes.AddTypedNode( new ComfyNodeBuilder.KSampler diff --git a/StabilityMatrix.Core/Helper/RemoteModels.cs b/StabilityMatrix.Core/Helper/RemoteModels.cs index 3f6a49b5..deed5fb1 100644 --- a/StabilityMatrix.Core/Helper/RemoteModels.cs +++ b/StabilityMatrix.Core/Helper/RemoteModels.cs @@ -167,8 +167,14 @@ public static class RemoteModels ) }; + public static HybridModelFile ControlNetReferenceOnlyModel { get; } = + HybridModelFile.FromRemote("@ReferenceOnly"); + public static IReadOnlyList ControlNetModels { get; } = - ControlNets.Select(HybridModelFile.FromDownloadable).ToImmutableArray(); + ControlNets + .Select(HybridModelFile.FromDownloadable) + .Concat([ControlNetReferenceOnlyModel]) + .ToImmutableArray(); private static IEnumerable PromptExpansions => [ diff --git a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs index 286dff0b..b0040a5d 100644 --- a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs +++ b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs @@ -7,6 +7,7 @@ using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Models.Inference; namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes; @@ -342,6 +343,34 @@ public class ComfyNodeBuilder public bool LogPrompt { get; init; } } + [TypedNodeOptions( + Name = "Inference_Core_AIO_Preprocessor", + RequiredExtensions = ["https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes >= 0.2.0"] + )] + public record AIOPreprocessor : ComfyTypedNodeBase + { + public required ImageNodeConnection Image { get; init; } + + public required string Preprocessor { get; init; } + + [Range(64, 2048)] + public int Resolution { get; init; } = 512; + } + + [TypedNodeOptions( + Name = "Inference_Core_ReferenceOnlySimple", + RequiredExtensions = ["https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes >= 0.3.0"] + )] + public record ReferenceOnlySimple : ComfyTypedNodeBase + { + public required ModelNodeConnection Model { get; init; } + + public required LatentNodeConnection Reference { get; init; } + + [Range(1, 64)] + public int BatchSize { get; init; } = 1; + } + public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae) { var name = GetUniqueName("VAEDecode"); @@ -818,6 +847,14 @@ public class ComfyNodeBuilder public ModelConnections Base => Models["Base"]; public ModelConnections Refiner => Models["Refiner"]; + public Dictionary SamplerTemporaryArgs { get; } = new(); + + public ModuleApplyStepTemporaryArgs? BaseSamplerTemporaryArgs + { + get => SamplerTemporaryArgs.GetValueOrDefault("Base"); + set => SamplerTemporaryArgs["Base"] = value; + } + public PrimaryNodeConnection? Primary { get; set; } public VAENodeConnection? PrimaryVAE { get; set; } public Size PrimarySize { get; set; } diff --git a/StabilityMatrix.Core/Models/HybridModelFile.cs b/StabilityMatrix.Core/Models/HybridModelFile.cs index 5e0d7275..fbe81d27 100644 --- a/StabilityMatrix.Core/Models/HybridModelFile.cs +++ b/StabilityMatrix.Core/Models/HybridModelFile.cs @@ -1,5 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; +using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models.Database; namespace StabilityMatrix.Core.Models; @@ -67,6 +68,11 @@ public record HybridModelFile return "Default"; } + if (ReferenceEquals(this, RemoteModels.ControlNetReferenceOnlyModel)) + { + return "Reference Only"; + } + var fileName = Path.GetFileNameWithoutExtension(RelativePath); if ( diff --git a/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs b/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs new file mode 100644 index 00000000..c867fcbc --- /dev/null +++ b/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs @@ -0,0 +1,36 @@ +using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; + +namespace StabilityMatrix.Core.Models.Inference; + +public class ModuleApplyStepTemporaryArgs +{ + /// + /// Temporary Primary apply step, used by ControlNet ReferenceOnly which changes the latent. + /// + public PrimaryNodeConnection? Primary { get; set; } + + public VAENodeConnection? PrimaryVAE { get; set; } + + public Dictionary Models { get; set; } = + new() { ["Base"] = new ModelConnections("Base"), ["Refiner"] = new ModelConnections("Refiner") }; + + public ModelConnections Base => Models["Base"]; + public ModelConnections Refiner => Models["Refiner"]; + + public ConditioningConnections GetRefinerOrBaseConditioning() + { + return Refiner.Conditioning + ?? Base.Conditioning + ?? throw new NullReferenceException("No Refiner or Base Conditioning"); + } + + public ModelNodeConnection GetRefinerOrBaseModel() + { + return Refiner.Model ?? Base.Model ?? throw new NullReferenceException("No Refiner or Base Model"); + } + + public VAENodeConnection GetDefaultVAE() + { + return PrimaryVAE ?? Refiner.VAE ?? Base.VAE ?? throw new NullReferenceException("No VAE"); + } +}