Browse Source

Merge pull request #531 from ionite34/prompt-expansion

Prompt expansion
pull/495/head
Ionite 9 months ago committed by GitHub
parent
commit
75cdf9b673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      CHANGELOG.md
  2. 1
      StabilityMatrix.Avalonia/App.axaml
  3. 50
      StabilityMatrix.Avalonia/Controls/Inference/PromptCard.axaml
  4. 96
      StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml
  5. 52
      StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml.cs
  6. 18
      StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
  7. 8
      StabilityMatrix.Avalonia/Models/Inference/PromptCardModel.cs
  8. 11
      StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
  9. 40
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  10. 1
      StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
  11. 131
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  12. 14
      StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs
  13. 1
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs
  14. 1
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  15. 40
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/PromptExpansionModule.cs
  16. 50
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  17. 47
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptExpansionCardViewModel.cs
  18. 12
      StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs
  19. 2
      StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
  20. 10
      StabilityMatrix.Core/Api/ApiFactory.cs
  21. 6
      StabilityMatrix.Core/Api/IApiFactory.cs
  22. 14
      StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs
  23. 49
      StabilityMatrix.Core/Converters/Json/NodeConnectionBaseJsonConverter.cs
  24. 63
      StabilityMatrix.Core/Converters/Json/OneOfJsonConverter.cs
  25. 22
      StabilityMatrix.Core/Helper/RemoteModels.cs
  26. 28
      StabilityMatrix.Core/Inference/ComfyClient.cs
  27. 18
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnectionBase.cs
  28. 26
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnections.cs
  29. 37
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  30. 25
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs
  31. 33
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs
  32. 1
      StabilityMatrix.Core/Models/HybridModelFile.cs
  33. 10
      StabilityMatrix.Core/Models/Packages/ComfyUI.cs
  34. 5
      StabilityMatrix.Core/Models/Packages/Extensions/ExtensionManifest.cs
  35. 78
      StabilityMatrix.Core/Models/Packages/Extensions/GitPackageExtensionManager.cs
  36. 29
      StabilityMatrix.Core/Models/Packages/Extensions/IPackageExtensionManager.cs
  37. 2
      StabilityMatrix.Core/Models/SharedFolderType.cs
  38. 2
      StabilityMatrix.Core/Models/TrackedDownload.cs

1
CHANGELOG.md

@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
## v2.9.0-dev.3
### Added
- Added Inference Prompt Styles, with Prompt Expansion model support (i.e. Fooocus V2)
- Added copy image support on linux and macOS for Inference outputs viewer menu
### Fixed
- Fixed StableSwarmUI not installing properly on macOS

1
StabilityMatrix.Avalonia/App.axaml

@ -78,6 +78,7 @@
<StyleInclude Source="Controls/Inference/SharpenCard.axaml"/>
<StyleInclude Source="Controls/Inference/FreeUCard.axaml"/>
<StyleInclude Source="Controls/Inference/ControlNetCard.axaml"/>
<StyleInclude Source="Controls/Inference/PromptExpansionCard.axaml"/>
<Style Selector="DockControl">
<Setter Property="(DockProperties.ControlRecycling)" Value="{StaticResource ControlRecyclingKey}" />

50
StabilityMatrix.Avalonia/Controls/Inference/PromptCard.axaml

@ -33,7 +33,7 @@
</Style>
</controls:Card.Styles>
<Grid RowDefinitions="*,16,*">
<Grid RowDefinitions="*,16,*,16,Auto">
<!-- Prompt -->
<Grid ColumnDefinitions="*,Auto" RowDefinitions="Auto,*">
<StackPanel
@ -137,6 +137,54 @@
</Border>
</Grid>
<GridSplitter
Grid.Row="3"
MaxWidth="45"
VerticalAlignment="Center"
BorderThickness="1"
CornerRadius="4"
Opacity="0.3" />
<controls:StackEditableCard
Margin="2,0,0,0"
DataContext="{Binding ModulesCardViewModel}"
Grid.Row="4">
</controls:StackEditableCard>
<!-- Styles and Prompt Expansions -->
<!--<Grid Grid.Row="4" RowDefinitions="Auto,*">
<StackPanel Margin="4,0,4,8" Orientation="Horizontal">
<TextBlock FontSize="14" Text="Styles" />
<icons:Icon
Margin="8,0"
FontSize="10"
Value="fa-solid fa-caret-down" />
</StackPanel>
<Border
Grid.Row="1"
Classes="theme-dark"
VerticalAlignment="Stretch"
CornerRadius="4">
<avaloniaEdit:TextEditor
x:Name="ExtraPromptEditor"
Document="{Binding NegativePromptDocument}"
FontFamily="Cascadia Code,Consolas,Menlo,Monospace">
<i:Interaction.Behaviors>
<behaviors:TextEditorCompletionBehavior
CompletionProvider="{Binding CompletionProvider}"
IsEnabled="{Binding IsAutoCompletionEnabled}"
TokenizerProvider="{Binding TokenizerProvider}" />
<behaviors:TextEditorToolTipBehavior IsEnabled="False" TokenizerProvider="{Binding TokenizerProvider}" />
</i:Interaction.Behaviors>
</avaloniaEdit:TextEditor>
</Border>
</Grid>-->
</Grid>
</controls:Card>
</ControlTemplate>

96
StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml

@ -0,0 +1,96 @@
<Styles
xmlns="https://github.com/avaloniaui"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
xmlns:converters="clr-namespace:StabilityMatrix.Avalonia.Converters"
xmlns:inference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData"
xmlns:models="clr-namespace:StabilityMatrix.Core.Models;assembly=StabilityMatrix.Core"
xmlns:sg="clr-namespace:SpacedGridControl.Avalonia;assembly=SpacedGridControl.Avalonia"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:input="clr-namespace:FluentAvalonia.UI.Input;assembly=FluentAvalonia"
xmlns:fluentIcons="clr-namespace:FluentIcons.Avalonia.Fluent;assembly=FluentIcons.Avalonia.Fluent"
xmlns:local="clr-namespace:StabilityMatrix.Avalonia"
x:DataType="inference:PromptExpansionCardViewModel">
<Design.PreviewWith>
<Panel Width="400" Height="200">
<StackPanel Width="300" VerticalAlignment="Center">
<controls:PromptExpansionCard />
</StackPanel>
</Panel>
</Design.PreviewWith>
<Style Selector="controls|PromptExpansionCard">
<Setter Property="HorizontalAlignment" Value="Stretch" />
<Setter Property="Template">
<ControlTemplate>
<controls:Card Padding="12">
<sg:SpacedGrid
ColumnDefinitions="Auto,*"
ColumnSpacing="8"
RowDefinitions="*,*,*,*"
RowSpacing="0">
<!-- Model -->
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="{x:Static lang:Resources.Label_Model}"
TextAlignment="Left" />
<ui:FAComboBox
x:Name="PART_ModelComboBox"
Grid.Row="0"
Grid.Column="1"
HorizontalAlignment="Stretch"
ItemContainerTheme="{StaticResource FAComboBoxItemHybridModelTheme}"
ItemsSource="{Binding ClientManager.PromptExpansionModels}"
SelectedItem="{Binding SelectedModel}">
<ui:FAComboBox.Resources>
<input:StandardUICommand x:Key="RemoteDownloadCommand"
Command="{Binding RemoteDownloadCommand}" />
</ui:FAComboBox.Resources>
<ui:FAComboBox.DataTemplates>
<controls:HybridModelTemplateSelector>
<DataTemplate x:Key="{x:Static models:HybridModelType.Downloadable}" DataType="models:HybridModelFile">
<Grid ColumnDefinitions="*,Auto">
<TextBlock Foreground="{DynamicResource ThemeGreyColor}" Text="{Binding ShortDisplayName}" />
<Button
Grid.Column="1"
Margin="8,0,0,0"
Padding="0"
Classes="transparent-full">
<fluentIcons:SymbolIcon
VerticalAlignment="Center"
FontSize="18"
Foreground="{DynamicResource ThemeGreyColor}"
IsFilled="True"
Symbol="CloudArrowDown" />
</Button>
</Grid>
</DataTemplate>
<DataTemplate x:Key="{x:Static models:HybridModelType.None}" DataType="models:HybridModelFile">
<TextBlock Text="{Binding ShortDisplayName}" />
</DataTemplate>
</controls:HybridModelTemplateSelector>
</ui:FAComboBox.DataTemplates>
</ui:FAComboBox>
<!--<controls:BetterComboBox
Grid.Row="0"
Grid.Column="1"
Padding="8,6,4,6"
HorizontalAlignment="Stretch"
ItemsSource="{Binding ClientManager.Upscalers}"
SelectedItem="{Binding SelectedModel}"
Theme="{StaticResource BetterComboBoxHybridModelTheme}" />-->
</sg:SpacedGrid>
</controls:Card>
</ControlTemplate>
</Setter>
</Style>
</Styles>

52
StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml.cs

@ -0,0 +1,52 @@
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Controls;
[Transient]
public class PromptExpansionCard : TemplatedControl
{
/// <inheritdoc />
protected override void OnApplyTemplate(TemplateAppliedEventArgs e)
{
base.OnApplyTemplate(e);
var upscalerComboBox = e.NameScope.Find("PART_ModelComboBox") as FAComboBox;
upscalerComboBox!.SelectionChanged += UpscalerComboBox_OnSelectionChanged;
}
private void UpscalerComboBox_OnSelectionChanged(object? sender, SelectionChangedEventArgs e)
{
if (e.AddedItems.Count == 0)
return;
var item = e.AddedItems[0];
if (item is HybridModelFile { IsDownloadable: true })
{
// Reset the selection
e.Handled = true;
if (
e.RemovedItems.Count > 0
&& e.RemovedItems[0] is HybridModelFile { IsDownloadable: false } removedItem
)
{
(sender as FAComboBox)!.SelectedItem = removedItem;
}
else
{
(sender as FAComboBox)!.SelectedItem = null;
}
// Show dialog to download the model
(DataContext as PromptExpansionCardViewModel)!
.RemoteDownloadCommand.ExecuteAsync(item)
.SafeFireAndForget();
}
}
}

18
StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

@ -27,6 +27,9 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults);
@ -64,28 +67,19 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
}
/// <inheritdoc />
public Task CopyImageToInputAsync(
FilePath imageFile,
CancellationToken cancellationToken = default
)
public Task CopyImageToInputAsync(FilePath imageFile, CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
/// <inheritdoc />
public Task UploadInputImageAsync(
ImageSource image,
CancellationToken cancellationToken = default
)
public Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
/// <inheritdoc />
public Task WriteImageToInputAsync(
ImageSource imageSource,
CancellationToken cancellationToken = default
)
public Task WriteImageToInputAsync(ImageSource imageSource, CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}

8
StabilityMatrix.Avalonia/Models/Inference/PromptCardModel.cs

@ -1,10 +1,10 @@
using System.Text.Json.Serialization;
using System.Text.Json.Nodes;
namespace StabilityMatrix.Avalonia.Models.Inference;
[JsonSerializable(typeof(PromptCardModel))]
public class PromptCardModel
{
public string? Prompt { get; set; }
public string? NegativePrompt { get; set; }
public string? Prompt { get; init; }
public string? NegativePrompt { get; init; }
public JsonObject? ModulesCardState { get; init; }
}

11
StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

@ -12,10 +12,7 @@ using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Avalonia.Services;
public interface IInferenceClientManager
: IDisposable,
INotifyPropertyChanged,
INotifyPropertyChanging
public interface IInferenceClientManager : IDisposable, INotifyPropertyChanged, INotifyPropertyChanging
{
ComfyClient? Client { get; set; }
@ -43,6 +40,7 @@ public interface IInferenceClientManager
IObservableCollection<HybridModelFile> Models { get; }
IObservableCollection<HybridModelFile> VaeModels { get; }
IObservableCollection<HybridModelFile> ControlNetModels { get; }
IObservableCollection<HybridModelFile> PromptExpansionModels { get; }
IObservableCollection<ComfySampler> Samplers { get; }
IObservableCollection<ComfyUpscaler> Upscalers { get; }
IObservableCollection<ComfyScheduler> Schedulers { get; }
@ -51,10 +49,7 @@ public interface IInferenceClientManager
Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default);
Task WriteImageToInputAsync(
ImageSource imageSource,
CancellationToken cancellationToken = default
);
Task WriteImageToInputAsync(ImageSource imageSource, CancellationToken cancellationToken = default);
Task ConnectAsync(CancellationToken cancellationToken = default);

40
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -74,6 +74,14 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<HybridModelFile, string> promptExpansionModelsSource = new(p => p.GetId());
private readonly SourceCache<HybridModelFile, string> downloadablePromptExpansionModelsSource =
new(p => p.GetId());
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);
public IObservableCollection<ComfySampler> Samplers { get; } =
@ -130,6 +138,18 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Bind(ControlNetModels)
.Subscribe();
promptExpansionModelsSource
.Connect()
.Or(downloadablePromptExpansionModelsSource.Connect())
.Sort(
SortExpressionComparer<HybridModelFile>
.Ascending(f => f.Type)
.ThenByAscending(f => f.ShortDisplayName)
)
.DeferUntilLoaded()
.Bind(PromptExpansionModels)
.Subscribe();
vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);
vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe();
@ -199,6 +219,8 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
);
}
// Prompt Expansion indexing is local only
// Fetch sampler names from KSampler node
if (await Client.GetSamplerNamesAsync() is { } samplerNames)
{
@ -277,6 +299,22 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
);
downloadableControlNetModelsSource.EditDiff(downloadableControlNets, HybridModelFile.Comparer);
// Load local prompt expansion models
promptExpansionModelsSource.EditDiff(
modelIndexService
.GetFromModelIndex(SharedFolderType.PromptExpansion)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer
);
// Downloadable PromptExpansion models
downloadablePromptExpansionModelsSource.EditDiff(
RemoteModels.PromptExpansionModels.Where(
u => !promptExpansionModelsSource.Lookup(u.GetId()).HasValue
),
HybridModelFile.Comparer
);
// Load local VAE models
vaeModelsSource.EditDiff(
modelIndexService.GetFromModelIndex(SharedFolderType.VAE).Select(HybridModelFile.FromLocal),
@ -481,7 +519,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
await ConnectAsyncImpl(uri, cancellationToken);
// Set package path as server path
Client.LocalServerPackage = packagePair;
Client.LocalServerPath = packagePair.InstalledPackage.FullPath!;
}

1
StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

@ -83,6 +83,7 @@
<PackageReference Include="NLog" Version="5.2.8" />
<PackageReference Include="NLog.Extensions.Logging" Version="5.3.8" />
<PackageReference Include="NSubstitute" Version="5.1.0" />
<PackageReference Include="OneOf" Version="3.0.263" />
<PackageReference Include="Polly" Version="8.2.1" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="Polly.Extensions.Http" Version="3.0.0" />

131
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -15,11 +15,15 @@ using Avalonia.Controls.Notifications;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using ExifLibrary;
using FluentAvalonia.UI.Controls;
using Microsoft.Extensions.DependencyInjection;
using Nito.Disposables.Internals;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
@ -35,6 +39,8 @@ using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Packages.Extensions;
using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Services;
using Notification = DesktopNotifications.Notification;
@ -272,6 +278,15 @@ public abstract partial class InferenceGenerationViewModelBase
if (client.OutputImagesDir is null)
throw new InvalidOperationException("OutputImagesDir is null");
// Only check extensions for first batch index
if (args.BatchIndex == 0)
{
if (!await CheckPromptExtensionsInstalled(args.Nodes))
{
throw new ValidationException("Prompt extensions not installed");
}
}
// Upload input images
await UploadInputImages(client);
@ -621,6 +636,121 @@ public abstract partial class InferenceGenerationViewModelBase
return ClientManager.IsConnected;
}
/// <summary>
/// Shows a dialog and return false if prompt required extensions not installed
/// </summary>
private async Task<bool> CheckPromptExtensionsInstalled(NodeDictionary nodeDictionary)
{
// Get prompt required extensions
// Just static for now but could do manifest lookup when we support custom workflows
var requiredExtensions = nodeDictionary
.ClassTypeRequiredExtensions.Values.SelectMany(x => x)
.ToHashSet();
// Skip if no extensions required
if (requiredExtensions.Count == 0)
{
return true;
}
// Get installed extensions
var localPackagePair = ClientManager.Client?.LocalServerPackage.Unwrap()!;
var manager = localPackagePair.BasePackage.ExtensionManager.Unwrap();
var localExtensions = (
await ((GitPackageExtensionManager)manager).GetInstalledExtensionsLiteAsync(
localPackagePair.InstalledPackage
)
).ToImmutableArray();
var missingExtensions = requiredExtensions
.Except(localExtensions.Select(ext => ext.GitRepositoryUrl).WhereNotNull())
.ToImmutableArray();
if (missingExtensions.Length == 0)
{
return true;
}
var dialog = DialogHelper.CreateMarkdownDialog(
$"#### The following extensions are required for this workflow:\n"
+ $"{string.Join("\n- ", missingExtensions)}",
"Install Required Extensions?"
);
dialog.IsPrimaryButtonEnabled = true;
dialog.DefaultButton = ContentDialogButton.Primary;
dialog.PrimaryButtonText =
$"{Resources.Action_Install} ({localPackagePair.InstalledPackage.DisplayName.ToRepr()} will restart)";
dialog.CloseButtonText = Resources.Action_Cancel;
if (await dialog.ShowAsync() == ContentDialogResult.Primary)
{
var manifestExtensionsMap = await manager.GetManifestExtensionsMapAsync(
manager.GetManifests(localPackagePair.InstalledPackage)
);
var steps = new List<IPackageStep>();
foreach (var missingExtensionUrl in missingExtensions)
{
if (!manifestExtensionsMap.TryGetValue(missingExtensionUrl, out var extension))
{
Logger.Warn(
"Extension {MissingExtensionUrl} not found in manifests",
missingExtensionUrl
);
continue;
}
steps.Add(new InstallExtensionStep(manager, localPackagePair.InstalledPackage, extension));
}
var runner = new PackageModificationRunner
{
ShowDialogOnStart = true,
ModificationCompleteTitle = "Extensions Installed",
ModificationCompleteMessage = "Finished installing required extensions"
};
EventManager.Instance.OnPackageInstallProgressAdded(runner);
runner
.ExecuteSteps(steps)
.ContinueWith(async _ =>
{
if (runner.Failed)
return;
// Restart Package
// TODO: This should be handled by some DI package manager service
var launchPage = App.Services.GetRequiredService<LaunchPageViewModel>();
try
{
await Dispatcher.UIThread.InvokeAsync(async () =>
{
await launchPage.Stop();
await launchPage.LaunchAsync();
});
}
catch (Exception e)
{
Logger.Error(e, "Error while restarting package");
notificationService.ShowPersistent(
new AppException(
"Could not restart package",
"Please manually restart the package for extension changes to take effect"
)
);
}
})
.SafeFireAndForget();
}
return false;
}
/// <summary>
/// Handles the preview image received event from the websocket.
/// Updates the preview image in the image gallery.
@ -683,6 +813,7 @@ public abstract partial class InferenceGenerationViewModelBase
public required ComfyClient Client { get; init; }
public required NodeDictionary Nodes { get; init; }
public required IReadOnlyList<string> OutputNodeNames { get; init; }
public int BatchIndex { get; init; }
public GenerationParameters? Parameters { get; init; }
public InferenceProjectDocument? Project { get; init; }
public bool ClearOutputImages { get; init; } = true;

14
StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs

@ -19,11 +19,13 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(ControlNetCardViewModel), ControlNetCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(PromptExpansionCardViewModel), PromptExpansionCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(FreeUModule))]
[JsonDerivedType(typeof(HiresFixModule))]
[JsonDerivedType(typeof(UpscalerModule))]
[JsonDerivedType(typeof(ControlNetModule))]
[JsonDerivedType(typeof(SaveImageModule))]
[JsonDerivedType(typeof(PromptExpansionModule))]
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -32,7 +34,8 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) };
private static readonly JsonSerializerOptions SerializerOptions = new() { IgnoreReadOnlyProperties = true };
private static readonly JsonSerializerOptions SerializerOptions =
new() { IgnoreReadOnlyProperties = true };
private static bool ShouldIgnoreProperty(PropertyInfo property)
{
@ -243,7 +246,11 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
}
else
{
Logger.ConditionalTrace("Serializing {Property} ({Type})", property.Name, property.PropertyType);
Logger.ConditionalTrace(
"Serializing {Property} ({Type})",
property.Name,
property.PropertyType
);
var value = property.GetValue(this);
if (value is not null)
{
@ -266,7 +273,8 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
protected static JsonObject SerializeModel<T>(T model)
{
var node = JsonSerializer.SerializeToNode(model);
return node?.AsObject() ?? throw new NullReferenceException("Failed to serialize state to JSON object.");
return node?.AsObject()
?? throw new NullReferenceException("Failed to serialize state to JSON object.");
}
/// <summary>

1
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs

@ -200,6 +200,7 @@ public partial class InferenceImageToVideoViewModel
Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this),
FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch
ClearOutputImages = i == 0
};

1
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -214,6 +214,7 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, I
Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this),
FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch
ClearOutputImages = i == 0
};

40
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/PromptExpansionModule.cs

@ -0,0 +1,40 @@
using System;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
[ManagedService]
[Transient]
public class PromptExpansionModule : ModuleBase
{
public PromptExpansionModule(ServiceManager<ViewModelBase> vmFactory)
: base(vmFactory)
{
Title = "Prompt Expansion";
AddCards(vmFactory.Get<PromptExpansionCardViewModel>());
}
protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{
var promptExpansionCard = GetCard<PromptExpansionCardViewModel>();
var model =
promptExpansionCard.SelectedModel
?? throw new InvalidOperationException($"{Title}: Model not selected");
e.Builder.Connections.PositivePrompt = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.PromptExpansion
{
Name = e.Nodes.GetUniqueName("PromptExpansion_Positive"),
ModelName = model.RelativePath,
Text = e.Builder.Connections.PositivePrompt,
Seed = e.Builder.Connections.Seed,
LogPrompt = promptExpansionCard.IsLogOutputEnabled
}
).Output;
}
}

50
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,5 +1,4 @@
using System;
using System.Linq;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
@ -13,12 +12,13 @@ using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
@ -43,6 +43,8 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
public TextDocument PromptDocument { get; } = new();
public TextDocument NegativePromptDocument { get; } = new();
public StackEditableCardViewModel ModulesCardViewModel { get; }
[ObservableProperty]
private bool isAutoCompletionEnabled;
@ -52,6 +54,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
ITokenizerProvider tokenizerProvider,
ISettingsManager settingsManager,
IModelIndexService modelIndexService,
ServiceManager<ViewModelBase> vmFactory,
SharedState sharedState
)
{
@ -60,6 +63,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
TokenizerProvider = tokenizerProvider;
SharedState = sharedState;
ModulesCardViewModel = vmFactory.Get<StackEditableCardViewModel>(vm =>
{
vm.Title = "Styles";
vm.AvailableModules = [typeof(PromptExpansionModule)];
});
settingsManager.RelayPropertyFor(
this,
vm => vm.IsAutoCompletionEnabled,
@ -84,8 +93,14 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
// Load prompts
var positivePrompt = GetPrompt();
positivePrompt.Process();
e.Builder.Connections.PositivePrompt = positivePrompt.ProcessedText;
var negativePrompt = GetNegativePrompt();
negativePrompt.Process();
e.Builder.Connections.NegativePrompt = negativePrompt.ProcessedText;
// Apply modules / styles that may modify the prompt
ModulesCardViewModel.ApplyStep(e);
foreach (var modelConnections in e.Builder.Connections.Models.Values)
{
@ -98,7 +113,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList();
// Add group to load loras onto model and clip in series
var lorasGroup = e.Builder.Group_LoraLoadMany($"Loras_{modelConnections.Name}", model, clip, loras);
var lorasGroup = e.Builder.Group_LoraLoadMany(
$"Loras_{modelConnections.Name}",
model,
clip,
loras
);
// Set last outputs as model and clip
modelConnections.Model = lorasGroup.Output1;
@ -111,7 +131,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
{
Name = $"PositiveCLIP_{modelConnections.Name}",
Clip = e.Builder.Connections.Base.Clip!,
Text = positivePrompt.ProcessedText
Text = e.Builder.Connections.PositivePrompt
}
);
var negativeClip = e.Nodes.AddTypedNode(
@ -119,7 +139,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
{
Name = $"NegativeCLIP_{modelConnections.Name}",
Clip = e.Builder.Connections.Base.Clip!,
Text = negativePrompt.ProcessedText
Text = e.Builder.Connections.NegativePrompt
}
);
@ -319,7 +339,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
public override JsonObject SaveStateToJsonObject()
{
return SerializeModel(
new PromptCardModel { Prompt = PromptDocument.Text, NegativePrompt = NegativePromptDocument.Text }
new PromptCardModel
{
Prompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text,
ModulesCardState = ModulesCardViewModel.SaveStateToJsonObject()
}
);
}
@ -330,6 +355,11 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
PromptDocument.Text = model.Prompt ?? "";
NegativePromptDocument.Text = model.NegativePrompt ?? "";
if (model.ModulesCardState is not null)
{
ModulesCardViewModel.LoadStateFromJsonObject(model.ModulesCardState);
}
}
/// <inheritdoc />
@ -342,6 +372,10 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return parameters with { PositivePrompt = PromptDocument.Text, NegativePrompt = NegativePromptDocument.Text };
return parameters with
{
PositivePrompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text
};
}
}

47
StabilityMatrix.Avalonia/ViewModels/Inference/PromptExpansionCardViewModel.cs

@ -0,0 +1,47 @@
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptExpansionCard))]
[ManagedService]
[Transient]
public partial class PromptExpansionCardViewModel(
IInferenceClientManager clientManager,
ServiceManager<ViewModelBase> vmFactory
) : LoadableViewModelBase
{
public const string ModuleKey = "PromptExpansion";
public IInferenceClientManager ClientManager { get; } = clientManager;
[ObservableProperty]
private HybridModelFile? selectedModel;
[ObservableProperty]
private bool isLogOutputEnabled = true;
[RelayCommand]
private async Task RemoteDownload(HybridModelFile? model)
{
if (model?.DownloadableResource is not { } resource)
return;
var confirmDialog = vmFactory.Get<DownloadResourceViewModel>();
confirmDialog.Resource = resource;
confirmDialog.FileName = resource.FileName;
if (await confirmDialog.GetDialog().ShowAsync() == ContentDialogResult.Primary)
{
confirmDialog.StartDownload();
}
}
}

12
StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs

@ -5,6 +5,7 @@ using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
@ -16,7 +17,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(StackEditableCard))]
[ManagedService]
[Transient]
public partial class StackEditableCardViewModel : StackViewModelBase
public partial class StackEditableCardViewModel : StackViewModelBase, IComfyStep
{
private readonly ServiceManager<ViewModelBase> vmFactory;
@ -68,6 +69,15 @@ public partial class StackEditableCardViewModel : StackViewModelBase
}
}
/// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e)
{
foreach (var module in Cards.OfType<IComfyStep>())
{
module.ApplyStep(e);
}
}
/// <inheritdoc />
protected override void OnCardAdded(LoadableViewModelBase item)
{

2
StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs

@ -206,7 +206,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
}
[RelayCommand]
private async Task LaunchAsync(string? command = null)
public async Task LaunchAsync(string? command = null)
{
await notificationService.TryAsync(LaunchImpl(command));
}

10
StabilityMatrix.Core/Api/ApiFactory.cs

@ -6,7 +6,7 @@ public class ApiFactory : IApiFactory
{
private readonly IHttpClientFactory httpClientFactory;
public RefitSettings? RefitSettings { get; init; }
public ApiFactory(IHttpClientFactory httpClientFactory)
{
this.httpClientFactory = httpClientFactory;
@ -18,4 +18,12 @@ public class ApiFactory : IApiFactory
httpClient.BaseAddress = baseAddress;
return RestService.For<T>(httpClient, RefitSettings);
}
public T CreateRefitClient<T>(Uri baseAddress, RefitSettings refitSettings)
{
var httpClient = httpClientFactory.CreateClient(nameof(T));
httpClient.BaseAddress = baseAddress;
return RestService.For<T>(httpClient, refitSettings);
}
}

6
StabilityMatrix.Core/Api/IApiFactory.cs

@ -1,6 +1,10 @@
namespace StabilityMatrix.Core.Api;
using Refit;
namespace StabilityMatrix.Core.Api;
public interface IApiFactory
{
public T CreateRefitClient<T>(Uri baseAddress);
public T CreateRefitClient<T>(Uri baseAddress, RefitSettings refitSettings);
}

14
StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs

@ -0,0 +1,14 @@
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Core.Attributes;
/// <summary>
/// Options for <see cref="ComfyTypedNodeBase{TOutput}"/>
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
public class TypedNodeOptionsAttribute : Attribute
{
public string? Name { get; init; }
public string[]? RequiredExtensions { get; init; }
}

49
StabilityMatrix.Core/Converters/Json/NodeConnectionBaseJsonConverter.cs

@ -0,0 +1,49 @@
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Core.Converters.Json;
public class NodeConnectionBaseJsonConverter : JsonConverter<NodeConnectionBase>
{
/// <inheritdoc />
public override NodeConnectionBase Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options
)
{
// Read as Data array
reader.Read();
var data = new object[2];
reader.Read();
data[0] = reader.GetString() ?? throw new JsonException("Expected string for node name");
reader.Read();
data[1] = reader.GetInt32();
reader.Read();
if (Activator.CreateInstance(typeToConvert) is not NodeConnectionBase instance)
{
throw new JsonException($"Failed to create instance of {typeToConvert}");
}
var propertyInfo =
typeToConvert.GetProperty("Data", BindingFlags.Public | BindingFlags.Instance)
?? throw new JsonException($"Failed to get Data property of {typeToConvert}");
propertyInfo.SetValue(instance, data);
return instance;
}
/// <inheritdoc />
public override void Write(Utf8JsonWriter writer, NodeConnectionBase value, JsonSerializerOptions options)
{
// Write as Data array
writer.WriteStartArray();
writer.WriteStringValue(value.Data?[0] as string);
writer.WriteNumberValue((int)value.Data?[1]!);
writer.WriteEndArray();
}
}

63
StabilityMatrix.Core/Converters/Json/OneOfJsonConverter.cs

@ -0,0 +1,63 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using OneOf;
namespace StabilityMatrix.Core.Converters.Json;
public class OneOfJsonConverter<T1, T2> : JsonConverter<OneOf<T1, T2>>
{
/// <inheritdoc />
public override OneOf<T1, T2> Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options
)
{
// Not sure how else to do this without polymorphic type markers but that would not serialize into T1/T2
// So just try to deserialize T1, if it fails, try T2
Exception? t1Exception = null;
Exception? t2Exception = null;
try
{
if (JsonSerializer.Deserialize<T1>(ref reader, options) is { } t1)
{
return t1;
}
}
catch (JsonException e)
{
t1Exception = e;
}
try
{
if (JsonSerializer.Deserialize<T2>(ref reader, options) is { } t2)
{
return t2;
}
}
catch (JsonException e)
{
t2Exception = e;
}
throw new JsonException(
$"Failed to deserialize OneOf<{typeof(T1)}, {typeof(T2)}> as either {typeof(T1)} or {typeof(T2)}",
new AggregateException([t1Exception, t2Exception])
);
}
/// <inheritdoc />
public override void Write(Utf8JsonWriter writer, OneOf<T1, T2> value, JsonSerializerOptions options)
{
if (value.IsT0)
{
JsonSerializer.Serialize(writer, value.AsT0, options);
}
else
{
JsonSerializer.Serialize(writer, value.AsT1, options);
}
}
}

22
StabilityMatrix.Core/Helper/RemoteModels.cs

@ -132,8 +132,7 @@ public static class RemoteModels
}
};
private static Uri ControlNetRoot { get; } =
new("https://huggingface.co/lllyasviel/ControlNet/");
private static Uri ControlNetRoot { get; } = new("https://huggingface.co/lllyasviel/ControlNet/");
private static RemoteResource ControlNetCommon(string path, string sha256)
{
@ -170,4 +169,23 @@ public static class RemoteModels
public static IReadOnlyList<HybridModelFile> ControlNetModels { get; } =
ControlNets.Select(HybridModelFile.FromDownloadable).ToImmutableArray();
private static IEnumerable<RemoteResource> PromptExpansions =>
[
new RemoteResource
{
Url = new Uri("https://cdn.lykos.ai/models/GPT-Prompt-Expansion-Fooocus-v2.zip"),
HashSha256 = "82e69311787c0bb6736389710d80c0a2b653ed9bbe6ea6e70c6b90820fe42d88",
InfoUrl = new Uri("https://huggingface.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2"),
Author = "lllyasviel, LykosAI",
LicenseType = "GPLv3",
LicenseUrl = new Uri("https://github.com/lllyasviel/Fooocus/blob/main/LICENSE"),
ContextType = SharedFolderType.PromptExpansion,
AutoExtractArchive = true,
ExtractRelativePath = "GPT-Prompt-Expansion-Fooocus-v2"
}
];
public static IEnumerable<HybridModelFile> PromptExpansionModels =>
PromptExpansions.Select(HybridModelFile.FromDownloadable);
}

28
StabilityMatrix.Core/Inference/ComfyClient.cs

@ -6,10 +6,13 @@ using NLog;
using Polly.Contrib.WaitAndRetry;
using Refit;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Converters.Json;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using Websocket.Client;
@ -26,8 +29,16 @@ public class ComfyClient : InferenceClientBase
private readonly IComfyApi comfyApi;
private bool isDisposed;
private JsonSerializerOptions jsonSerializerOptions =
new() { PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower, };
private readonly JsonSerializerOptions jsonSerializerOptions =
new()
{
PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower,
Converters =
{
new NodeConnectionBaseJsonConverter(),
new OneOfJsonConverter<string, StringNodeConnection>()
}
};
// ReSharper disable once MemberCanBePrivate.Global
public string ClientId { get; } = Guid.NewGuid().ToString();
@ -39,6 +50,11 @@ public class ComfyClient : InferenceClientBase
/// </summary>
public DirectoryPath? LocalServerPath { get; set; }
/// <summary>
/// If available, the local server package pair
/// </summary>
public PackagePair? LocalServerPackage { get; set; }
/// <summary>
/// Path to the "output" folder from LocalServerPath
/// </summary>
@ -81,7 +97,13 @@ public class ComfyClient : InferenceClientBase
public ComfyClient(IApiFactory apiFactory, Uri baseAddress)
{
comfyApi = apiFactory.CreateRefitClient<IComfyApi>(baseAddress);
comfyApi = apiFactory.CreateRefitClient<IComfyApi>(
baseAddress,
new RefitSettings
{
ContentSerializer = new SystemTextJsonContentSerializer(jsonSerializerOptions),
}
);
BaseAddress = baseAddress;
// Setup websocket client

18
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnectionBase.cs

@ -1,12 +1,14 @@
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Converters.Json;
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
[JsonConverter(typeof(NodeConnectionBaseJsonConverter))]
public abstract class NodeConnectionBase
{
public object[]? Data { get; set; }
// Implicit conversion to object[]
public static implicit operator object[](NodeConnectionBase nodeConnection)
{
return nodeConnection.Data ?? Array.Empty<object>();
}
/// <summary>
/// Array data for the connection.
/// [(string) Node Name, (int) Connection Index]
/// </summary>
public object[]? Data { get; init; }
}

26
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnections.cs

@ -1,25 +1,27 @@
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
public class LatentNodeConnection : NodeConnectionBase { }
public class LatentNodeConnection : NodeConnectionBase;
public class VAENodeConnection : NodeConnectionBase { }
public class VAENodeConnection : NodeConnectionBase;
public class ImageNodeConnection : NodeConnectionBase { }
public class ImageNodeConnection : NodeConnectionBase;
public class ImageMaskConnection : NodeConnectionBase { }
public class ImageMaskConnection : NodeConnectionBase;
public class UpscaleModelNodeConnection : NodeConnectionBase { }
public class UpscaleModelNodeConnection : NodeConnectionBase;
public class ModelNodeConnection : NodeConnectionBase { }
public class ModelNodeConnection : NodeConnectionBase;
public class ConditioningNodeConnection : NodeConnectionBase { }
public class ConditioningNodeConnection : NodeConnectionBase;
public class ClipNodeConnection : NodeConnectionBase { }
public class ClipNodeConnection : NodeConnectionBase;
public class ControlNetNodeConnection : NodeConnectionBase { }
public class ControlNetNodeConnection : NodeConnectionBase;
public class ClipVisionNodeConnection : NodeConnectionBase { }
public class ClipVisionNodeConnection : NodeConnectionBase;
public class SamplerNodeConnection : NodeConnectionBase { }
public class SamplerNodeConnection : NodeConnectionBase;
public class SigmasNodeConnection : NodeConnectionBase { }
public class SigmasNodeConnection : NodeConnectionBase;
public class StringNodeConnection : NodeConnectionBase;

37
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -1,8 +1,8 @@
using System.ComponentModel.DataAnnotations;
using System.ComponentModel;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using System.Runtime.Serialization;
using System.Text.Json.Serialization;
using OneOf;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
@ -14,6 +14,7 @@ namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
/// Builder functions for comfy nodes
/// </summary>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[Localizable(false)]
public class ComfyNodeBuilder
{
public NodeDictionary Nodes { get; } = new();
@ -258,20 +259,7 @@ public class ComfyNodeBuilder
public record CLIPTextEncode : ComfyTypedNodeBase<ConditioningNodeConnection>
{
public required ClipNodeConnection Clip { get; init; }
public required string Text { get; init; }
}
public static NamedComfyNode<ConditioningNodeConnection> ClipTextEncode(
string name,
ClipNodeConnection clip,
string text
)
{
return new NamedComfyNode<ConditioningNodeConnection>(name)
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?> { ["clip"] = clip.Data, ["text"] = text }
};
public required OneOf<string, StringNodeConnection> Text { get; init; }
}
public record LoadImage : ComfyTypedNodeBase<ImageNodeConnection, ImageMaskConnection>
@ -342,6 +330,18 @@ public class ComfyNodeBuilder
public required string Method { get; init; }
}
[TypedNodeOptions(
Name = "Inference_Core_PromptExpansion",
RequiredExtensions = ["https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes"]
)]
public record PromptExpansion : ComfyTypedNodeBase<StringNodeConnection>
{
public required string ModelName { get; init; }
public required OneOf<string, StringNodeConnection> Text { get; init; }
public required ulong Seed { get; init; }
public bool LogPrompt { get; init; }
}
public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
{
var name = GetUniqueName("VAEDecode");
@ -801,6 +801,9 @@ public class ComfyNodeBuilder
public int BatchSize { get; set; } = 1;
public int? BatchIndex { get; set; }
public OneOf<string, StringNodeConnection> PositivePrompt { get; set; }
public OneOf<string, StringNodeConnection> NegativePrompt { get; set; }
public ClipNodeConnection? BaseClip { get; set; }
public ClipVisionNodeConnection? BaseClipVision { get; set; }

25
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs

@ -1,4 +1,5 @@
using System.Reflection;
using System.ComponentModel;
using System.Reflection;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
@ -8,8 +9,28 @@ namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
public abstract record ComfyTypedNodeBase
{
protected virtual string ClassType => GetType().Name;
[Localizable(false)]
protected virtual string ClassType
{
get
{
var type = GetType();
// Use options name if available
if (type.GetCustomAttribute<TypedNodeOptionsAttribute>() is { } options)
{
if (!string.IsNullOrEmpty(options.Name))
{
return options.Name;
}
}
// Otherwise use class name
return type.Name;
}
}
[Localizable(false)]
[JsonIgnore]
public required string Name { get; init; }

33
StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs

@ -1,4 +1,9 @@
using StabilityMatrix.Core.Helper;
using System.ComponentModel;
using System.Reflection;
using System.Text.Json.Serialization;
using OneOf;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
@ -10,10 +15,16 @@ public class NodeDictionary : Dictionary<string, ComfyNode>
/// </summary>
private readonly Dictionary<string, int> _baseNameIndex = new();
/// <summary>
/// When inserting TypedNodes, this holds a mapping of ClassType to required extensions
/// </summary>
[JsonIgnore]
public Dictionary<string, string[]> ClassTypeRequiredExtensions { get; } = new();
/// <summary>
/// Finds a unique node name given a base name, by appending _2, _3, etc.
/// </summary>
public string GetUniqueName(string nameBase)
public string GetUniqueName([Localizable(false)] string nameBase)
{
if (_baseNameIndex.TryGetValue(nameBase, out var index))
{
@ -43,7 +54,19 @@ public class NodeDictionary : Dictionary<string, ComfyNode>
public TTypedNode AddTypedNode<TTypedNode>(TTypedNode node)
where TTypedNode : ComfyTypedNodeBase
{
Add(node.Name, node);
var namedNode = (NamedComfyNode)node;
Add(node.Name, namedNode);
// Check statically annotated stuff for TypedNodeOptionsAttribute
if (node.GetType().GetCustomAttribute<TypedNodeOptionsAttribute>() is { } options)
{
if (options.RequiredExtensions != null)
{
ClassTypeRequiredExtensions[namedNode.ClassType] = options.RequiredExtensions;
}
}
return node;
}
@ -62,6 +85,10 @@ public class NodeDictionary : Dictionary<string, ComfyNode>
{
node.Inputs[key] = connection.Data;
}
else if (input is IOneOf { Value: NodeConnectionBase oneOfConnection })
{
node.Inputs[key] = oneOfConnection.Data;
}
}
}
}

1
StabilityMatrix.Core/Models/HybridModelFile.cs

@ -71,6 +71,7 @@ public record HybridModelFile
if (
!fileName.Equals("diffusion_pytorch_model", StringComparison.OrdinalIgnoreCase)
&& !fileName.Equals("pytorch_model", StringComparison.OrdinalIgnoreCase)
&& !fileName.Equals("ip_adapter", StringComparison.OrdinalIgnoreCase)
)
{

10
StabilityMatrix.Core/Models/Packages/ComfyUI.cs

@ -68,6 +68,7 @@ public class ComfyUI(
[SharedFolderType.InvokeIpAdapters15] = new[] { "models/ipadapter/sd15" },
[SharedFolderType.InvokeIpAdaptersXl] = new[] { "models/ipadapter/sdxl" },
[SharedFolderType.T2IAdapter] = new[] { "models/controlnet/T2IAdapter" },
[SharedFolderType.PromptExpansion] = new[] { "models/prompt_expansion" }
};
public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders =>
@ -373,6 +374,7 @@ public class ComfyUI(
Path.Combine(modelsDir, "InvokeIpAdapters15"),
Path.Combine(modelsDir, "InvokeIpAdaptersXl")
);
nodeValue.Children["prompt_expansion"] = Path.Combine(modelsDir, "PromptExpansion");
}
else
{
@ -410,7 +412,8 @@ public class ComfyUI(
Path.Combine(modelsDir, "InvokeIpAdapters15"),
Path.Combine(modelsDir, "InvokeIpAdaptersXl")
)
}
},
{ "prompt_expansion", Path.Combine(modelsDir, "PromptExpansion") }
}
);
}
@ -476,9 +479,8 @@ public class ComfyUI(
public override IEnumerable<ExtensionManifest> DefaultManifests =>
[
new ExtensionManifest(
new Uri("https://cdn.jsdelivr.net/gh/ltdrdata/ComfyUI-Manager/custom-node-list.json")
)
"https://cdn.jsdelivr.net/gh/ltdrdata/ComfyUI-Manager/custom-node-list.json",
"https://cdn.jsdelivr.net/gh/LykosAI/ComfyUI-Extensions-Index/custom-node-list.json"
];
public override async Task<IEnumerable<PackageExtension>> GetManifestExtensionsAsync(

5
StabilityMatrix.Core/Models/Packages/Extensions/ExtensionManifest.cs

@ -1,3 +1,6 @@
namespace StabilityMatrix.Core.Models.Packages.Extensions;
public record ExtensionManifest(Uri Uri);
public record ExtensionManifest(Uri Uri)
{
public static implicit operator ExtensionManifest(string uri) => new(new Uri(uri, UriKind.Absolute));
}

78
StabilityMatrix.Core/Models/Packages/Extensions/GitPackageExtensionManager.cs

@ -1,4 +1,6 @@
using System.Text.RegularExpressions;
using KGySoft.CoreLibraries;
using Microsoft.Extensions.Caching.Memory;
using NLog;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
@ -8,11 +10,14 @@ using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Models.Packages.Extensions;
public abstract class GitPackageExtensionManager(IPrerequisiteHelper prerequisiteHelper)
public abstract partial class GitPackageExtensionManager(IPrerequisiteHelper prerequisiteHelper)
: IPackageExtensionManager
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
// Cache checks of installed extensions
private readonly MemoryCache installedExtensionsCache = new(new MemoryCacheOptions());
public abstract string RelativeInstallDirectory { get; }
public virtual IEnumerable<ExtensionManifest> DefaultManifests { get; } =
@ -128,6 +133,74 @@ public abstract class GitPackageExtensionManager(IPrerequisiteHelper prerequisit
return extensions;
}
/// <summary>
/// Like <see cref="GetInstalledExtensionsAsync"/>, but does not check git version and repository url.
/// </summary>
public virtual async Task<IEnumerable<InstalledPackageExtension>> GetInstalledExtensionsLiteAsync(
InstalledPackage installedPackage,
CancellationToken cancellationToken = default
)
{
if (installedPackage.FullPath is not { } packagePath)
{
return Enumerable.Empty<InstalledPackageExtension>();
}
var extensions = new List<InstalledPackageExtension>();
// Search for installed extensions in the package's index directories.
foreach (
var indexDirectory in IndexRelativeDirectories.Select(
path => new DirectoryPath(packagePath, path)
)
)
{
cancellationToken.ThrowIfCancellationRequested();
// Skip directory if not exists
if (!indexDirectory.Exists)
{
continue;
}
// Check subdirectories of the index directory
foreach (var subDirectory in indexDirectory.EnumerateDirectories())
{
cancellationToken.ThrowIfCancellationRequested();
// Skip if not valid git repository
if (!subDirectory.JoinDir(".git").Exists)
{
continue;
}
// Get remote url with manual parsing
string? remoteUrl = null;
var gitConfigPath = subDirectory.JoinDir(".git").JoinFile("config");
if (
gitConfigPath.Exists
&& await gitConfigPath.ReadAllTextAsync(cancellationToken).ConfigureAwait(false)
is { } gitConfigText
)
{
var pattern = GitConfigRemoteOriginUrlRegex();
var match = pattern.Match(gitConfigText);
if (match.Success)
{
remoteUrl = match.Groups[1].Value;
}
}
extensions.Add(
new InstalledPackageExtension { Paths = [subDirectory], GitRepositoryUrl = remoteUrl }
);
}
}
return extensions;
}
/// <inheritdoc />
public virtual async Task InstallExtensionAsync(
PackageExtension extension,
@ -242,4 +315,7 @@ public abstract class GitPackageExtensionManager(IPrerequisiteHelper prerequisit
progress?.Report(new ProgressReport(1f, message: "Uninstalled extension"));
}
[GeneratedRegex("""\[remote "origin"\][\s\S]*?url\s*=\s*(.+)""")]
private static partial Regex GitConfigRemoteOriginUrlRegex();
}

29
StabilityMatrix.Core/Models/Packages/Extensions/IPackageExtensionManager.cs

@ -52,6 +52,35 @@ public interface IPackageExtensionManager
return extensions;
}
/// <summary>
/// Get unique extensions from all provided manifests. As a mapping of their reference.
/// </summary>
async Task<IDictionary<string, PackageExtension>> GetManifestExtensionsMapAsync(
IEnumerable<ExtensionManifest> manifests,
CancellationToken cancellationToken = default
)
{
var result = new Dictionary<string, PackageExtension>();
foreach (
var extension in await GetManifestExtensionsAsync(manifests, cancellationToken)
.ConfigureAwait(false)
)
{
cancellationToken.ThrowIfCancellationRequested();
var key = extension.Reference.ToString();
if (!result.TryAdd(key, extension))
{
// Replace
result[key] = extension;
}
}
return result;
}
/// <summary>
/// Get all installed extensions for the provided package.
/// </summary>

2
StabilityMatrix.Core/Models/SharedFolderType.cs

@ -39,4 +39,6 @@ public enum SharedFolderType
InvokeIpAdapters15 = 1 << 24,
InvokeIpAdaptersXl = 1 << 25,
InvokeClipVision = 1 << 26,
PromptExpansion = 1 << 30
}

2
StabilityMatrix.Core/Models/TrackedDownload.cs

@ -124,6 +124,8 @@ public class TrackedDownload
{
var progress = new Progress<ProgressReport>(OnProgressUpdate);
DownloadDirectory.Create();
await downloadService!
.ResumeDownloadToFileAsync(
SourceUrl.ToString(),

Loading…
Cancel
Save