From f781fbb2292be294dbcd7dce4419e4b965ff8ebb Mon Sep 17 00:00:00 2001 From: Ionite Date: Mon, 11 Mar 2024 18:20:55 -0400 Subject: [PATCH] Fix ReferenceOnly ControlNet batch image outputting original as well --- .../Inference/Modules/ControlNetModule.cs | 7 ++++ .../Inference/SamplerCardViewModel.cs | 35 ++++++++++++++----- .../Inference/ModuleApplyStepTemporaryArgs.cs | 10 ++++++ 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs index c170a5d9..21213098 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs @@ -114,6 +114,13 @@ public class ControlNetModule : ModuleBase } e.Temp.Primary = controlNetReferenceOnly.Output2; + // Indicate that the Primary latent has been temp batched + // https://github.com/comfyanonymous/ComfyUI_experiments/issues/11 + + e.Temp.IsPrimaryTempBatched = true; + // Index 0 is the original image, index 1 is the reference only latent + e.Temp.PrimaryTempBatchPickIndex = 1; + return; } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index 3ab194a6..8890a514 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -222,12 +222,9 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo ); e.Builder.Connections.Primary = sampler.Output1; - - return; } - // Use KSampler if no refiner, otherwise need KSamplerAdvanced - if (e.Builder.Connections.Refiner.Model is null) + else if (e.Builder.Connections.Refiner.Model is null) { // No refiner var sampler = e.Nodes.AddTypedNode( @@ -272,8 +269,30 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo } ); + e.Builder.Connections.Primary = sampler.Output; + } + + // If temp batched, add a LatentFromBatch to pick the temp batch right after first sampler + if (e.Temp.IsPrimaryTempBatched) + { + e.Builder.Connections.Primary = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.LatentFromBatch + { + Name = e.Nodes.GetUniqueName("ControlNet_LatentFromBatch"), + Samples = e.Builder.GetPrimaryAsLatent(), + BatchIndex = e.Temp.PrimaryTempBatchPickIndex, + // Use max length here as recommended + // https://github.com/comfyanonymous/ComfyUI_experiments/issues/11 + Length = 64 + } + ).Output; + } + + // Refiner + if (e.Builder.Connections.Refiner.Model is not null) + { // Add refiner sampler - var refinerSampler = e.Nodes.AddTypedNode( + e.Builder.Connections.Primary = e.Nodes.AddTypedNode( new ComfyNodeBuilder.KSamplerAdvanced { Name = "Sampler_Refiner", @@ -287,14 +306,12 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo Positive = refinerConditioning!.Positive, Negative = refinerConditioning.Negative, // Connect to previous sampler - LatentImage = sampler.Output, + LatentImage = e.Builder.GetPrimaryAsLatent(), StartAtStep = Steps, EndAtStep = TotalSteps, ReturnWithLeftoverNoise = false } - ); - - e.Builder.Connections.Primary = refinerSampler.Output; + ).Output; } } diff --git a/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs b/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs index c867fcbc..e76ea875 100644 --- a/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs +++ b/StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs @@ -11,6 +11,16 @@ public class ModuleApplyStepTemporaryArgs public VAENodeConnection? PrimaryVAE { get; set; } + /// + /// Used by Reference-Only ControlNet to indicate that has been batched. + /// + public bool IsPrimaryTempBatched { get; set; } + + /// + /// When is true, this is the index of the temp batch to pick after sampling. + /// + public int PrimaryTempBatchPickIndex { get; set; } + public Dictionary Models { get; set; } = new() { ["Base"] = new ModelConnections("Base"), ["Refiner"] = new ModelConnections("Refiner") };