Browse Source

Add ControlNet models to InferenceClientManager

pull/333/head
Ionite 1 year ago
parent
commit
33f445f486
No known key found for this signature in database
  1. 3
      StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
  2. 1
      StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
  3. 38
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  4. 5
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

3
StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

@ -24,6 +24,9 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
public IObservableCollection<HybridModelFile> VaeModels { get; } = public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>(); new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<ComfySampler> Samplers { get; } = public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults); new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults);

1
StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

@ -43,6 +43,7 @@ public interface IInferenceClientManager
IObservableCollection<HybridModelFile> Models { get; } IObservableCollection<HybridModelFile> Models { get; }
IObservableCollection<HybridModelFile> VaeModels { get; } IObservableCollection<HybridModelFile> VaeModels { get; }
IObservableCollection<HybridModelFile> ControlNetModels { get; }
IObservableCollection<ComfySampler> Samplers { get; } IObservableCollection<ComfySampler> Samplers { get; }
IObservableCollection<ComfyUpscaler> Upscalers { get; } IObservableCollection<ComfyUpscaler> Upscalers { get; }
IObservableCollection<ComfyScheduler> Schedulers { get; } IObservableCollection<ComfyScheduler> Schedulers { get; }

38
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -65,6 +65,12 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
public IObservableCollection<HybridModelFile> VaeModels { get; } = public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>(); new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<HybridModelFile, string> controlNetModelsSource =
new(p => p.GetId());
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name); private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);
public IObservableCollection<ComfySampler> Samplers { get; } = public IObservableCollection<ComfySampler> Samplers { get; } =
@ -110,6 +116,17 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Bind(Models) .Bind(Models)
.Subscribe(); .Subscribe();
controlNetModelsSource
.Connect()
.SortBy(
f => f.ShortDisplayName,
SortDirection.Ascending,
SortOptimisations.ComparesImmutableValuesOnly
)
.DeferUntilLoaded()
.Bind(ControlNetModels)
.Subscribe();
vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default); vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);
vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe(); vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe();
@ -159,6 +176,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
if (!IsConnected) if (!IsConnected)
throw new InvalidOperationException("Client is not connected"); throw new InvalidOperationException("Client is not connected");
// Get model names
if (await Client.GetModelNamesAsync() is { } modelNames) if (await Client.GetModelNamesAsync() is { } modelNames)
{ {
modelsSource.EditDiff( modelsSource.EditDiff(
@ -167,6 +185,18 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
); );
} }
// Get control net model names
if (
await Client.GetNodeOptionNamesAsync("ControlNetLoader", "control_net_name") is
{ } controlNetModelNames
)
{
controlNetModelsSource.EditDiff(
controlNetModelNames.Select(HybridModelFile.FromRemote),
HybridModelFile.Comparer
);
}
// Fetch sampler names from KSampler node // Fetch sampler names from KSampler node
if (await Client.GetSamplerNamesAsync() is { } samplerNames) if (await Client.GetSamplerNamesAsync() is { } samplerNames)
{ {
@ -229,6 +259,14 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
HybridModelFile.Comparer HybridModelFile.Comparer
); );
// Load local control net models
controlNetModelsSource.EditDiff(
modelIndexService
.GetFromModelIndex(SharedFolderType.ControlNet)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer
);
// Load local VAE models // Load local VAE models
vaeModelsSource.EditDiff( vaeModelsSource.EditDiff(
modelIndexService modelIndexService

5
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -348,6 +348,11 @@ public class ComfyNodeBuilder
}; };
} }
public class ControlNetLoader : ComfyTypedNodeBase<ControlNetNodeConnection>
{
public required string ControlNetName { get; init; }
}
public class ControlNetApplyAdvanced public class ControlNetApplyAdvanced
: ComfyTypedNodeBase<ConditioningNodeConnection, ConditioningNodeConnection> : ComfyTypedNodeBase<ConditioningNodeConnection, ConditioningNodeConnection>
{ {

Loading…
Cancel
Save