Browse Source

Reference controlnet refactors

pull/629/head
Ionite 8 months ago
parent
commit
e7be967cdd
No known key found for this signature in database
  1. 36
      StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs
  2. 67
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs
  3. 57
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs
  4. 17
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
  5. 8
      StabilityMatrix.Core/Helper/RemoteModels.cs
  6. 37
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  7. 6
      StabilityMatrix.Core/Models/HybridModelFile.cs
  8. 36
      StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs

36
StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs

@ -4,7 +4,7 @@ using System.IO;
using System.IO.Hashing;
using System.Text;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Inference;
namespace StabilityMatrix.Avalonia.Models.Inference;
@ -17,7 +17,7 @@ public class ModuleApplyStepEventArgs : EventArgs
public NodeDictionary Nodes => Builder.Nodes;
public ModuleApplyStepTemporaryArgs Temp { get; } = new();
public ModuleApplyStepTemporaryArgs Temp { get; set; } = new();
/// <summary>
/// Generation overrides (like hires fix generate, current seed generate, etc.)
@ -26,6 +26,20 @@ public class ModuleApplyStepEventArgs : EventArgs
public List<(string SourcePath, string DestinationRelativePath)> FilesToTransfer { get; init; } = [];
/// <summary>
/// Creates a new <see cref="ModuleApplyStepEventArgs"/> with the given <see cref="ComfyNodeBuilder"/>.
/// </summary>
/// <returns></returns>
public ModuleApplyStepTemporaryArgs CreateTempFromBuilder()
{
return new ModuleApplyStepTemporaryArgs
{
Primary = Builder.Connections.Primary,
PrimaryVAE = Builder.Connections.PrimaryVAE,
Models = Builder.Connections.Models
};
}
public void AddFileTransfer(string sourcePath, string destinationRelativePath)
{
FilesToTransfer.Add((sourcePath, destinationRelativePath));
@ -54,22 +68,4 @@ public class ModuleApplyStepEventArgs : EventArgs
return destPath;
}
public class ModuleApplyStepTemporaryArgs
{
/// <summary>
/// Temporary conditioning apply step, used by samplers to apply control net.
/// </summary>
public ConditioningConnections? Conditioning { get; set; }
/// <summary>
/// Temporary refiner conditioning apply step, used by samplers to apply control net.
/// </summary>
public ConditioningConnections? RefinerConditioning { get; set; }
/// <summary>
/// Temporary model apply step, used by samplers to apply control net.
/// </summary>
public ModelNodeConnection? Model { get; set; }
}
}

67
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs

@ -6,6 +6,9 @@ using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
@ -49,6 +52,56 @@ public class ControlNetModule : ModuleBase
}
);
// If ReferenceOnly is selected, use special node
if (card.SelectedModel == RemoteModels.ControlNetReferenceOnlyModel)
{
// We need to rescale image to be the current primary size if it's not already
var primarySize = e.Builder.Connections.PrimarySize;
if (card.SelectImageCardViewModel.CurrentBitmapSize != primarySize)
{
var scaled = e.Builder.Group_Upscale(
e.Nodes.GetUniqueName("ControlNet_Rescale"),
image,
e.Temp.GetDefaultVAE(),
ComfyUpscaler.NearestExact,
primarySize.Width,
primarySize.Width
);
e.Temp.Primary = scaled;
}
else
{
e.Temp.Primary = image;
}
// Set image as new latent source, add reference only node
var model = e.Temp.GetRefinerOrBaseModel();
var controlNetReferenceOnly = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.ReferenceOnlySimple
{
Name = e.Nodes.GetUniqueName("ControlNet_ReferenceOnly"),
Reference = e.Builder.GetPrimaryAsLatent(
e.Temp.Primary,
e.Builder.Connections.GetDefaultVAE()
),
Model = model
}
);
// Set output as new primary and model source
if (model == e.Temp.Refiner.Model)
{
e.Temp.Refiner.Model = controlNetReferenceOnly.Output1;
}
else
{
e.Temp.Base.Model = controlNetReferenceOnly.Output1;
}
e.Temp.Primary = controlNetReferenceOnly.Output2;
return;
}
var controlNetLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.ControlNetLoader
{
@ -64,18 +117,18 @@ public class ControlNetModule : ModuleBase
Name = e.Nodes.GetUniqueName("ControlNetApply"),
Image = imageLoad.Output1,
ControlNet = controlNetLoader.Output,
Positive = e.Temp.Conditioning?.Positive ?? throw new ArgumentException("No Conditioning"),
Negative = e.Temp.Conditioning?.Negative ?? throw new ArgumentException("No Conditioning"),
Positive = e.Temp.Base.Conditioning!.Unwrap().Positive,
Negative = e.Temp.Base.Conditioning.Negative,
Strength = card.Strength,
StartPercent = card.StartPercent,
EndPercent = card.EndPercent,
}
);
e.Temp.Conditioning = (controlNetApply.Output1, controlNetApply.Output2);
e.Temp.Base.Conditioning = (controlNetApply.Output1, controlNetApply.Output2);
// Refiner if available
if (e.Temp.RefinerConditioning is not null)
if (e.Temp.Refiner.Conditioning is not null)
{
var controlNetRefinerApply = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.ControlNetApplyAdvanced
@ -83,15 +136,15 @@ public class ControlNetModule : ModuleBase
Name = e.Nodes.GetUniqueName("Refiner_ControlNetApply"),
Image = imageLoad.Output1,
ControlNet = controlNetLoader.Output,
Positive = e.Temp.RefinerConditioning.Positive,
Negative = e.Temp.RefinerConditioning.Negative,
Positive = e.Temp.Refiner.Conditioning!.Unwrap().Positive,
Negative = e.Temp.Refiner.Conditioning.Negative,
Strength = card.Strength,
StartPercent = card.StartPercent,
EndPercent = card.EndPercent,
}
);
e.Temp.RefinerConditioning = (controlNetRefinerApply.Output1, controlNetRefinerApply.Output2);
e.Temp.Refiner.Conditioning = (controlNetRefinerApply.Output1, controlNetRefinerApply.Output2);
}
}
}

57
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/HiresFixModule.cs

@ -78,30 +78,39 @@ public partial class HiresFixModule : ModuleBase
);
}
var hiresSampler = builder
.Nodes
.AddTypedNode(
new ComfyNodeBuilder.KSampler
{
Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"),
Model = builder.Connections.GetRefinerOrBaseModel(),
Seed = builder.Connections.Seed,
Steps = samplerCard.Steps,
Cfg = samplerCard.CfgScale,
SamplerName =
samplerCard.SelectedSampler?.Name
?? e.Builder.Connections.PrimarySampler?.Name
?? throw new ArgumentException("No PrimarySampler"),
Scheduler =
samplerCard.SelectedScheduler?.Name
?? e.Builder.Connections.PrimaryScheduler?.Name
?? throw new ArgumentException("No PrimaryScheduler"),
Positive = builder.Connections.GetRefinerOrBaseConditioning().Positive,
Negative = builder.Connections.GetRefinerOrBaseConditioning().Negative,
LatentImage = builder.GetPrimaryAsLatent(),
Denoise = samplerCard.DenoiseStrength
}
);
// If we need to inherit primary sampler addons, use their temp args
if (samplerCard.InheritPrimarySamplerAddons)
{
e.Temp = e.Builder.Connections.BaseSamplerTemporaryArgs ?? e.CreateTempFromBuilder();
}
else
{
// otherwise just use new ones
e.Temp = e.CreateTempFromBuilder();
}
var hiresSampler = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSampler
{
Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"),
Model = builder.Connections.GetRefinerOrBaseModel(),
Seed = builder.Connections.Seed,
Steps = samplerCard.Steps,
Cfg = samplerCard.CfgScale,
SamplerName =
samplerCard.SelectedSampler?.Name
?? e.Builder.Connections.PrimarySampler?.Name
?? throw new ArgumentException("No PrimarySampler"),
Scheduler =
samplerCard.SelectedScheduler?.Name
?? e.Builder.Connections.PrimaryScheduler?.Name
?? throw new ArgumentException("No PrimaryScheduler"),
Positive = e.Temp.GetRefinerOrBaseConditioning().Positive,
Negative = e.Temp.GetRefinerOrBaseConditioning().Negative,
LatentImage = builder.GetPrimaryAsLatent(),
Denoise = samplerCard.DenoiseStrength
}
);
// Set as primary
builder.Connections.Primary = hiresSampler.Output;

17
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -130,8 +130,7 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
}
// Provide temp values
e.Temp.Conditioning = e.Builder.Connections.Base.Conditioning;
e.Temp.RefinerConditioning = e.Builder.Connections.Refiner.Conditioning;
e.Temp = e.CreateTempFromBuilder();
// Apply steps from our addons
ApplyAddonSteps(e);
@ -142,6 +141,9 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
if (!e.Nodes.ContainsKey("Sampler"))
{
ApplyStepsInitialSampler(e);
// Save temp
e.Builder.Connections.BaseSamplerTemporaryArgs = e.Temp;
}
else
{
@ -152,7 +154,10 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
private void ApplyStepsInitialSampler(ModuleApplyStepEventArgs e)
{
// Get primary as latent using vae
var primaryLatent = e.Builder.GetPrimaryAsLatent();
var primaryLatent = e.Builder.GetPrimaryAsLatent(
e.Temp.Primary!.Unwrap(),
e.Builder.Connections.GetDefaultVAE()
);
// Set primary sampler and scheduler
var primarySampler = SelectedSampler ?? throw new ValidationException("Sampler not selected");
@ -162,8 +167,8 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
e.Builder.Connections.PrimaryScheduler = primaryScheduler;
// Use Temp Conditioning that may be modified by addons
var conditioning = e.Temp.Conditioning.Unwrap();
var refinerConditioning = e.Temp.RefinerConditioning;
var conditioning = e.Temp.Base.Conditioning.Unwrap();
var refinerConditioning = e.Temp.Refiner.Conditioning;
// Use custom sampler if SDTurbo scheduler is selected
if (e.Builder.Connections.PrimaryScheduler == ComfyScheduler.SDTurbo)
@ -216,8 +221,6 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
// Use KSampler if no refiner, otherwise need KSamplerAdvanced
if (e.Builder.Connections.Refiner.Model is null)
{
var baseConditioning = e.Builder.Connections.Base.Conditioning.Unwrap();
// No refiner
var sampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSampler

8
StabilityMatrix.Core/Helper/RemoteModels.cs

@ -167,8 +167,14 @@ public static class RemoteModels
)
};
public static HybridModelFile ControlNetReferenceOnlyModel { get; } =
HybridModelFile.FromRemote("@ReferenceOnly");
public static IReadOnlyList<HybridModelFile> ControlNetModels { get; } =
ControlNets.Select(HybridModelFile.FromDownloadable).ToImmutableArray();
ControlNets
.Select(HybridModelFile.FromDownloadable)
.Concat([ControlNetReferenceOnlyModel])
.ToImmutableArray();
private static IEnumerable<RemoteResource> PromptExpansions =>
[

37
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -7,6 +7,7 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.Inference;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
@ -342,6 +343,34 @@ public class ComfyNodeBuilder
public bool LogPrompt { get; init; }
}
[TypedNodeOptions(
Name = "Inference_Core_AIO_Preprocessor",
RequiredExtensions = ["https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes >= 0.2.0"]
)]
public record AIOPreprocessor : ComfyTypedNodeBase<ImageNodeConnection>
{
public required ImageNodeConnection Image { get; init; }
public required string Preprocessor { get; init; }
[Range(64, 2048)]
public int Resolution { get; init; } = 512;
}
[TypedNodeOptions(
Name = "Inference_Core_ReferenceOnlySimple",
RequiredExtensions = ["https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes >= 0.3.0"]
)]
public record ReferenceOnlySimple : ComfyTypedNodeBase<ModelNodeConnection, LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
public required LatentNodeConnection Reference { get; init; }
[Range(1, 64)]
public int BatchSize { get; init; } = 1;
}
public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
{
var name = GetUniqueName("VAEDecode");
@ -818,6 +847,14 @@ public class ComfyNodeBuilder
public ModelConnections Base => Models["Base"];
public ModelConnections Refiner => Models["Refiner"];
public Dictionary<string, ModuleApplyStepTemporaryArgs?> SamplerTemporaryArgs { get; } = new();
public ModuleApplyStepTemporaryArgs? BaseSamplerTemporaryArgs
{
get => SamplerTemporaryArgs.GetValueOrDefault("Base");
set => SamplerTemporaryArgs["Base"] = value;
}
public PrimaryNodeConnection? Primary { get; set; }
public VAENodeConnection? PrimaryVAE { get; set; }
public Size PrimarySize { get; set; }

6
StabilityMatrix.Core/Models/HybridModelFile.cs

@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Core.Models;
@ -67,6 +68,11 @@ public record HybridModelFile
return "Default";
}
if (ReferenceEquals(this, RemoteModels.ControlNetReferenceOnlyModel))
{
return "Reference Only";
}
var fileName = Path.GetFileNameWithoutExtension(RelativePath);
if (

36
StabilityMatrix.Core/Models/Inference/ModuleApplyStepTemporaryArgs.cs

@ -0,0 +1,36 @@
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Core.Models.Inference;
public class ModuleApplyStepTemporaryArgs
{
/// <summary>
/// Temporary Primary apply step, used by ControlNet ReferenceOnly which changes the latent.
/// </summary>
public PrimaryNodeConnection? Primary { get; set; }
public VAENodeConnection? PrimaryVAE { get; set; }
public Dictionary<string, ModelConnections> Models { get; set; } =
new() { ["Base"] = new ModelConnections("Base"), ["Refiner"] = new ModelConnections("Refiner") };
public ModelConnections Base => Models["Base"];
public ModelConnections Refiner => Models["Refiner"];
public ConditioningConnections GetRefinerOrBaseConditioning()
{
return Refiner.Conditioning
?? Base.Conditioning
?? throw new NullReferenceException("No Refiner or Base Conditioning");
}
public ModelNodeConnection GetRefinerOrBaseModel()
{
return Refiner.Model ?? Base.Model ?? throw new NullReferenceException("No Refiner or Base Model");
}
public VAENodeConnection GetDefaultVAE()
{
return PrimaryVAE ?? Refiner.VAE ?? Base.VAE ?? throw new NullReferenceException("No VAE");
}
}
Loading…
Cancel
Save