Browse Source

Add ControlNet models to InferenceClientManager

pull/333/head
Ionite 1 year ago
parent
commit
33f445f486
No known key found for this signature in database
  1. 3
      StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
  2. 1
      StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
  3. 38
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  4. 5
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

3
StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

@ -24,6 +24,9 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults);

1
StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

@ -43,6 +43,7 @@ public interface IInferenceClientManager
IObservableCollection<HybridModelFile> Models { get; }
IObservableCollection<HybridModelFile> VaeModels { get; }
IObservableCollection<HybridModelFile> ControlNetModels { get; }
IObservableCollection<ComfySampler> Samplers { get; }
IObservableCollection<ComfyUpscaler> Upscalers { get; }
IObservableCollection<ComfyScheduler> Schedulers { get; }

38
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -65,6 +65,12 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<HybridModelFile, string> controlNetModelsSource =
new(p => p.GetId());
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);
public IObservableCollection<ComfySampler> Samplers { get; } =
@ -110,6 +116,17 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Bind(Models)
.Subscribe();
controlNetModelsSource
.Connect()
.SortBy(
f => f.ShortDisplayName,
SortDirection.Ascending,
SortOptimisations.ComparesImmutableValuesOnly
)
.DeferUntilLoaded()
.Bind(ControlNetModels)
.Subscribe();
vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);
vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe();
@ -159,6 +176,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
if (!IsConnected)
throw new InvalidOperationException("Client is not connected");
// Get model names
if (await Client.GetModelNamesAsync() is { } modelNames)
{
modelsSource.EditDiff(
@ -167,6 +185,18 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
);
}
// Get control net model names
if (
await Client.GetNodeOptionNamesAsync("ControlNetLoader", "control_net_name") is
{ } controlNetModelNames
)
{
controlNetModelsSource.EditDiff(
controlNetModelNames.Select(HybridModelFile.FromRemote),
HybridModelFile.Comparer
);
}
// Fetch sampler names from KSampler node
if (await Client.GetSamplerNamesAsync() is { } samplerNames)
{
@ -229,6 +259,14 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
HybridModelFile.Comparer
);
// Load local control net models
controlNetModelsSource.EditDiff(
modelIndexService
.GetFromModelIndex(SharedFolderType.ControlNet)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer
);
// Load local VAE models
vaeModelsSource.EditDiff(
modelIndexService

5
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -348,6 +348,11 @@ public class ComfyNodeBuilder
};
}
public class ControlNetLoader : ComfyTypedNodeBase<ControlNetNodeConnection>
{
public required string ControlNetName { get; init; }
}
public class ControlNetApplyAdvanced
: ComfyTypedNodeBase<ConditioningNodeConnection, ConditioningNodeConnection>
{

Loading…
Cancel
Save