Browse Source

Merge pull request #531 from ionite34/prompt-expansion

Prompt expansion
pull/495/head
Ionite 9 months ago committed by GitHub
parent
commit
75cdf9b673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      CHANGELOG.md
  2. 1
      StabilityMatrix.Avalonia/App.axaml
  3. 50
      StabilityMatrix.Avalonia/Controls/Inference/PromptCard.axaml
  4. 96
      StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml
  5. 52
      StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml.cs
  6. 18
      StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
  7. 8
      StabilityMatrix.Avalonia/Models/Inference/PromptCardModel.cs
  8. 11
      StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
  9. 40
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  10. 1
      StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
  11. 131
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  12. 14
      StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs
  13. 1
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs
  14. 1
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  15. 40
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/PromptExpansionModule.cs
  16. 50
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  17. 47
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptExpansionCardViewModel.cs
  18. 12
      StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs
  19. 2
      StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
  20. 8
      StabilityMatrix.Core/Api/ApiFactory.cs
  21. 6
      StabilityMatrix.Core/Api/IApiFactory.cs
  22. 14
      StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs
  23. 49
      StabilityMatrix.Core/Converters/Json/NodeConnectionBaseJsonConverter.cs
  24. 63
      StabilityMatrix.Core/Converters/Json/OneOfJsonConverter.cs
  25. 22
      StabilityMatrix.Core/Helper/RemoteModels.cs
  26. 28
      StabilityMatrix.Core/Inference/ComfyClient.cs
  27. 18
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnectionBase.cs
  28. 26
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnections.cs
  29. 37
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  30. 25
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs
  31. 33
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs
  32. 1
      StabilityMatrix.Core/Models/HybridModelFile.cs
  33. 10
      StabilityMatrix.Core/Models/Packages/ComfyUI.cs
  34. 5
      StabilityMatrix.Core/Models/Packages/Extensions/ExtensionManifest.cs
  35. 78
      StabilityMatrix.Core/Models/Packages/Extensions/GitPackageExtensionManager.cs
  36. 29
      StabilityMatrix.Core/Models/Packages/Extensions/IPackageExtensionManager.cs
  37. 2
      StabilityMatrix.Core/Models/SharedFolderType.cs
  38. 2
      StabilityMatrix.Core/Models/TrackedDownload.cs

1
CHANGELOG.md

@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
## v2.9.0-dev.3 ## v2.9.0-dev.3
### Added ### Added
- Added Inference Prompt Styles, with Prompt Expansion model support (i.e. Fooocus V2)
- Added copy image support on linux and macOS for Inference outputs viewer menu - Added copy image support on linux and macOS for Inference outputs viewer menu
### Fixed ### Fixed
- Fixed StableSwarmUI not installing properly on macOS - Fixed StableSwarmUI not installing properly on macOS

1
StabilityMatrix.Avalonia/App.axaml

@ -78,6 +78,7 @@
<StyleInclude Source="Controls/Inference/SharpenCard.axaml"/> <StyleInclude Source="Controls/Inference/SharpenCard.axaml"/>
<StyleInclude Source="Controls/Inference/FreeUCard.axaml"/> <StyleInclude Source="Controls/Inference/FreeUCard.axaml"/>
<StyleInclude Source="Controls/Inference/ControlNetCard.axaml"/> <StyleInclude Source="Controls/Inference/ControlNetCard.axaml"/>
<StyleInclude Source="Controls/Inference/PromptExpansionCard.axaml"/>
<Style Selector="DockControl"> <Style Selector="DockControl">
<Setter Property="(DockProperties.ControlRecycling)" Value="{StaticResource ControlRecyclingKey}" /> <Setter Property="(DockProperties.ControlRecycling)" Value="{StaticResource ControlRecyclingKey}" />

50
StabilityMatrix.Avalonia/Controls/Inference/PromptCard.axaml

@ -33,7 +33,7 @@
</Style> </Style>
</controls:Card.Styles> </controls:Card.Styles>
<Grid RowDefinitions="*,16,*"> <Grid RowDefinitions="*,16,*,16,Auto">
<!-- Prompt --> <!-- Prompt -->
<Grid ColumnDefinitions="*,Auto" RowDefinitions="Auto,*"> <Grid ColumnDefinitions="*,Auto" RowDefinitions="Auto,*">
<StackPanel <StackPanel
@ -137,6 +137,54 @@
</Border> </Border>
</Grid> </Grid>
<GridSplitter
Grid.Row="3"
MaxWidth="45"
VerticalAlignment="Center"
BorderThickness="1"
CornerRadius="4"
Opacity="0.3" />
<controls:StackEditableCard
Margin="2,0,0,0"
DataContext="{Binding ModulesCardViewModel}"
Grid.Row="4">
</controls:StackEditableCard>
<!-- Styles and Prompt Expansions -->
<!--<Grid Grid.Row="4" RowDefinitions="Auto,*">
<StackPanel Margin="4,0,4,8" Orientation="Horizontal">
<TextBlock FontSize="14" Text="Styles" />
<icons:Icon
Margin="8,0"
FontSize="10"
Value="fa-solid fa-caret-down" />
</StackPanel>
<Border
Grid.Row="1"
Classes="theme-dark"
VerticalAlignment="Stretch"
CornerRadius="4">
<avaloniaEdit:TextEditor
x:Name="ExtraPromptEditor"
Document="{Binding NegativePromptDocument}"
FontFamily="Cascadia Code,Consolas,Menlo,Monospace">
<i:Interaction.Behaviors>
<behaviors:TextEditorCompletionBehavior
CompletionProvider="{Binding CompletionProvider}"
IsEnabled="{Binding IsAutoCompletionEnabled}"
TokenizerProvider="{Binding TokenizerProvider}" />
<behaviors:TextEditorToolTipBehavior IsEnabled="False" TokenizerProvider="{Binding TokenizerProvider}" />
</i:Interaction.Behaviors>
</avaloniaEdit:TextEditor>
</Border>
</Grid>-->
</Grid> </Grid>
</controls:Card> </controls:Card>
</ControlTemplate> </ControlTemplate>

96
StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml

@ -0,0 +1,96 @@
<Styles
xmlns="https://github.com/avaloniaui"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
xmlns:converters="clr-namespace:StabilityMatrix.Avalonia.Converters"
xmlns:inference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData"
xmlns:models="clr-namespace:StabilityMatrix.Core.Models;assembly=StabilityMatrix.Core"
xmlns:sg="clr-namespace:SpacedGridControl.Avalonia;assembly=SpacedGridControl.Avalonia"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:input="clr-namespace:FluentAvalonia.UI.Input;assembly=FluentAvalonia"
xmlns:fluentIcons="clr-namespace:FluentIcons.Avalonia.Fluent;assembly=FluentIcons.Avalonia.Fluent"
xmlns:local="clr-namespace:StabilityMatrix.Avalonia"
x:DataType="inference:PromptExpansionCardViewModel">
<Design.PreviewWith>
<Panel Width="400" Height="200">
<StackPanel Width="300" VerticalAlignment="Center">
<controls:PromptExpansionCard />
</StackPanel>
</Panel>
</Design.PreviewWith>
<Style Selector="controls|PromptExpansionCard">
<Setter Property="HorizontalAlignment" Value="Stretch" />
<Setter Property="Template">
<ControlTemplate>
<controls:Card Padding="12">
<sg:SpacedGrid
ColumnDefinitions="Auto,*"
ColumnSpacing="8"
RowDefinitions="*,*,*,*"
RowSpacing="0">
<!-- Model -->
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="{x:Static lang:Resources.Label_Model}"
TextAlignment="Left" />
<ui:FAComboBox
x:Name="PART_ModelComboBox"
Grid.Row="0"
Grid.Column="1"
HorizontalAlignment="Stretch"
ItemContainerTheme="{StaticResource FAComboBoxItemHybridModelTheme}"
ItemsSource="{Binding ClientManager.PromptExpansionModels}"
SelectedItem="{Binding SelectedModel}">
<ui:FAComboBox.Resources>
<input:StandardUICommand x:Key="RemoteDownloadCommand"
Command="{Binding RemoteDownloadCommand}" />
</ui:FAComboBox.Resources>
<ui:FAComboBox.DataTemplates>
<controls:HybridModelTemplateSelector>
<DataTemplate x:Key="{x:Static models:HybridModelType.Downloadable}" DataType="models:HybridModelFile">
<Grid ColumnDefinitions="*,Auto">
<TextBlock Foreground="{DynamicResource ThemeGreyColor}" Text="{Binding ShortDisplayName}" />
<Button
Grid.Column="1"
Margin="8,0,0,0"
Padding="0"
Classes="transparent-full">
<fluentIcons:SymbolIcon
VerticalAlignment="Center"
FontSize="18"
Foreground="{DynamicResource ThemeGreyColor}"
IsFilled="True"
Symbol="CloudArrowDown" />
</Button>
</Grid>
</DataTemplate>
<DataTemplate x:Key="{x:Static models:HybridModelType.None}" DataType="models:HybridModelFile">
<TextBlock Text="{Binding ShortDisplayName}" />
</DataTemplate>
</controls:HybridModelTemplateSelector>
</ui:FAComboBox.DataTemplates>
</ui:FAComboBox>
<!--<controls:BetterComboBox
Grid.Row="0"
Grid.Column="1"
Padding="8,6,4,6"
HorizontalAlignment="Stretch"
ItemsSource="{Binding ClientManager.Upscalers}"
SelectedItem="{Binding SelectedModel}"
Theme="{StaticResource BetterComboBoxHybridModelTheme}" />-->
</sg:SpacedGrid>
</controls:Card>
</ControlTemplate>
</Setter>
</Style>
</Styles>

52
StabilityMatrix.Avalonia/Controls/Inference/PromptExpansionCard.axaml.cs

@ -0,0 +1,52 @@
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Controls;
[Transient]
public class PromptExpansionCard : TemplatedControl
{
/// <inheritdoc />
protected override void OnApplyTemplate(TemplateAppliedEventArgs e)
{
base.OnApplyTemplate(e);
var upscalerComboBox = e.NameScope.Find("PART_ModelComboBox") as FAComboBox;
upscalerComboBox!.SelectionChanged += UpscalerComboBox_OnSelectionChanged;
}
private void UpscalerComboBox_OnSelectionChanged(object? sender, SelectionChangedEventArgs e)
{
if (e.AddedItems.Count == 0)
return;
var item = e.AddedItems[0];
if (item is HybridModelFile { IsDownloadable: true })
{
// Reset the selection
e.Handled = true;
if (
e.RemovedItems.Count > 0
&& e.RemovedItems[0] is HybridModelFile { IsDownloadable: false } removedItem
)
{
(sender as FAComboBox)!.SelectedItem = removedItem;
}
else
{
(sender as FAComboBox)!.SelectedItem = null;
}
// Show dialog to download the model
(DataContext as PromptExpansionCardViewModel)!
.RemoteDownloadCommand.ExecuteAsync(item)
.SafeFireAndForget();
}
}
}

18
StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

@ -27,6 +27,9 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
public IObservableCollection<HybridModelFile> ControlNetModels { get; } = public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>(); new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<ComfySampler> Samplers { get; } = public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults); new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults);
@ -64,28 +67,19 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
} }
/// <inheritdoc /> /// <inheritdoc />
public Task CopyImageToInputAsync( public Task CopyImageToInputAsync(FilePath imageFile, CancellationToken cancellationToken = default)
FilePath imageFile,
CancellationToken cancellationToken = default
)
{ {
return Task.CompletedTask; return Task.CompletedTask;
} }
/// <inheritdoc /> /// <inheritdoc />
public Task UploadInputImageAsync( public Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default)
ImageSource image,
CancellationToken cancellationToken = default
)
{ {
return Task.CompletedTask; return Task.CompletedTask;
} }
/// <inheritdoc /> /// <inheritdoc />
public Task WriteImageToInputAsync( public Task WriteImageToInputAsync(ImageSource imageSource, CancellationToken cancellationToken = default)
ImageSource imageSource,
CancellationToken cancellationToken = default
)
{ {
return Task.CompletedTask; return Task.CompletedTask;
} }

8
StabilityMatrix.Avalonia/Models/Inference/PromptCardModel.cs

@ -1,10 +1,10 @@
using System.Text.Json.Serialization; using System.Text.Json.Nodes;
namespace StabilityMatrix.Avalonia.Models.Inference; namespace StabilityMatrix.Avalonia.Models.Inference;
[JsonSerializable(typeof(PromptCardModel))]
public class PromptCardModel public class PromptCardModel
{ {
public string? Prompt { get; set; } public string? Prompt { get; init; }
public string? NegativePrompt { get; set; } public string? NegativePrompt { get; init; }
public JsonObject? ModulesCardState { get; init; }
} }

11
StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

@ -12,10 +12,7 @@ using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Avalonia.Services; namespace StabilityMatrix.Avalonia.Services;
public interface IInferenceClientManager public interface IInferenceClientManager : IDisposable, INotifyPropertyChanged, INotifyPropertyChanging
: IDisposable,
INotifyPropertyChanged,
INotifyPropertyChanging
{ {
ComfyClient? Client { get; set; } ComfyClient? Client { get; set; }
@ -43,6 +40,7 @@ public interface IInferenceClientManager
IObservableCollection<HybridModelFile> Models { get; } IObservableCollection<HybridModelFile> Models { get; }
IObservableCollection<HybridModelFile> VaeModels { get; } IObservableCollection<HybridModelFile> VaeModels { get; }
IObservableCollection<HybridModelFile> ControlNetModels { get; } IObservableCollection<HybridModelFile> ControlNetModels { get; }
IObservableCollection<HybridModelFile> PromptExpansionModels { get; }
IObservableCollection<ComfySampler> Samplers { get; } IObservableCollection<ComfySampler> Samplers { get; }
IObservableCollection<ComfyUpscaler> Upscalers { get; } IObservableCollection<ComfyUpscaler> Upscalers { get; }
IObservableCollection<ComfyScheduler> Schedulers { get; } IObservableCollection<ComfyScheduler> Schedulers { get; }
@ -51,10 +49,7 @@ public interface IInferenceClientManager
Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default); Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default);
Task WriteImageToInputAsync( Task WriteImageToInputAsync(ImageSource imageSource, CancellationToken cancellationToken = default);
ImageSource imageSource,
CancellationToken cancellationToken = default
);
Task ConnectAsync(CancellationToken cancellationToken = default); Task ConnectAsync(CancellationToken cancellationToken = default);

40
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -74,6 +74,14 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
public IObservableCollection<HybridModelFile> ControlNetModels { get; } = public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>(); new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<HybridModelFile, string> promptExpansionModelsSource = new(p => p.GetId());
private readonly SourceCache<HybridModelFile, string> downloadablePromptExpansionModelsSource =
new(p => p.GetId());
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name); private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);
public IObservableCollection<ComfySampler> Samplers { get; } = public IObservableCollection<ComfySampler> Samplers { get; } =
@ -130,6 +138,18 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Bind(ControlNetModels) .Bind(ControlNetModels)
.Subscribe(); .Subscribe();
promptExpansionModelsSource
.Connect()
.Or(downloadablePromptExpansionModelsSource.Connect())
.Sort(
SortExpressionComparer<HybridModelFile>
.Ascending(f => f.Type)
.ThenByAscending(f => f.ShortDisplayName)
)
.DeferUntilLoaded()
.Bind(PromptExpansionModels)
.Subscribe();
vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default); vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);
vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe(); vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe();
@ -199,6 +219,8 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
); );
} }
// Prompt Expansion indexing is local only
// Fetch sampler names from KSampler node // Fetch sampler names from KSampler node
if (await Client.GetSamplerNamesAsync() is { } samplerNames) if (await Client.GetSamplerNamesAsync() is { } samplerNames)
{ {
@ -277,6 +299,22 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
); );
downloadableControlNetModelsSource.EditDiff(downloadableControlNets, HybridModelFile.Comparer); downloadableControlNetModelsSource.EditDiff(downloadableControlNets, HybridModelFile.Comparer);
// Load local prompt expansion models
promptExpansionModelsSource.EditDiff(
modelIndexService
.GetFromModelIndex(SharedFolderType.PromptExpansion)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer
);
// Downloadable PromptExpansion models
downloadablePromptExpansionModelsSource.EditDiff(
RemoteModels.PromptExpansionModels.Where(
u => !promptExpansionModelsSource.Lookup(u.GetId()).HasValue
),
HybridModelFile.Comparer
);
// Load local VAE models // Load local VAE models
vaeModelsSource.EditDiff( vaeModelsSource.EditDiff(
modelIndexService.GetFromModelIndex(SharedFolderType.VAE).Select(HybridModelFile.FromLocal), modelIndexService.GetFromModelIndex(SharedFolderType.VAE).Select(HybridModelFile.FromLocal),
@ -481,7 +519,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
await ConnectAsyncImpl(uri, cancellationToken); await ConnectAsyncImpl(uri, cancellationToken);
// Set package path as server path Client.LocalServerPackage = packagePair;
Client.LocalServerPath = packagePair.InstalledPackage.FullPath!; Client.LocalServerPath = packagePair.InstalledPackage.FullPath!;
} }

1
StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

@ -83,6 +83,7 @@
<PackageReference Include="NLog" Version="5.2.8" /> <PackageReference Include="NLog" Version="5.2.8" />
<PackageReference Include="NLog.Extensions.Logging" Version="5.3.8" /> <PackageReference Include="NLog.Extensions.Logging" Version="5.3.8" />
<PackageReference Include="NSubstitute" Version="5.1.0" /> <PackageReference Include="NSubstitute" Version="5.1.0" />
<PackageReference Include="OneOf" Version="3.0.263" />
<PackageReference Include="Polly" Version="8.2.1" /> <PackageReference Include="Polly" Version="8.2.1" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" /> <PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="Polly.Extensions.Http" Version="3.0.0" /> <PackageReference Include="Polly.Extensions.Http" Version="3.0.0" />

131
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -15,11 +15,15 @@ using Avalonia.Controls.Notifications;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using ExifLibrary; using ExifLibrary;
using FluentAvalonia.UI.Controls;
using Microsoft.Extensions.DependencyInjection;
using Nito.Disposables.Internals;
using NLog; using NLog;
using Refit; using Refit;
using SkiaSharp; using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
@ -35,6 +39,8 @@ using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Packages.Extensions;
using StabilityMatrix.Core.Models.Settings; using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
using Notification = DesktopNotifications.Notification; using Notification = DesktopNotifications.Notification;
@ -272,6 +278,15 @@ public abstract partial class InferenceGenerationViewModelBase
if (client.OutputImagesDir is null) if (client.OutputImagesDir is null)
throw new InvalidOperationException("OutputImagesDir is null"); throw new InvalidOperationException("OutputImagesDir is null");
// Only check extensions for first batch index
if (args.BatchIndex == 0)
{
if (!await CheckPromptExtensionsInstalled(args.Nodes))
{
throw new ValidationException("Prompt extensions not installed");
}
}
// Upload input images // Upload input images
await UploadInputImages(client); await UploadInputImages(client);
@ -621,6 +636,121 @@ public abstract partial class InferenceGenerationViewModelBase
return ClientManager.IsConnected; return ClientManager.IsConnected;
} }
/// <summary>
/// Shows a dialog and return false if prompt required extensions not installed
/// </summary>
private async Task<bool> CheckPromptExtensionsInstalled(NodeDictionary nodeDictionary)
{
// Get prompt required extensions
// Just static for now but could do manifest lookup when we support custom workflows
var requiredExtensions = nodeDictionary
.ClassTypeRequiredExtensions.Values.SelectMany(x => x)
.ToHashSet();
// Skip if no extensions required
if (requiredExtensions.Count == 0)
{
return true;
}
// Get installed extensions
var localPackagePair = ClientManager.Client?.LocalServerPackage.Unwrap()!;
var manager = localPackagePair.BasePackage.ExtensionManager.Unwrap();
var localExtensions = (
await ((GitPackageExtensionManager)manager).GetInstalledExtensionsLiteAsync(
localPackagePair.InstalledPackage
)
).ToImmutableArray();
var missingExtensions = requiredExtensions
.Except(localExtensions.Select(ext => ext.GitRepositoryUrl).WhereNotNull())
.ToImmutableArray();
if (missingExtensions.Length == 0)
{
return true;
}
var dialog = DialogHelper.CreateMarkdownDialog(
$"#### The following extensions are required for this workflow:\n"
+ $"{string.Join("\n- ", missingExtensions)}",
"Install Required Extensions?"
);
dialog.IsPrimaryButtonEnabled = true;
dialog.DefaultButton = ContentDialogButton.Primary;
dialog.PrimaryButtonText =
$"{Resources.Action_Install} ({localPackagePair.InstalledPackage.DisplayName.ToRepr()} will restart)";
dialog.CloseButtonText = Resources.Action_Cancel;
if (await dialog.ShowAsync() == ContentDialogResult.Primary)
{
var manifestExtensionsMap = await manager.GetManifestExtensionsMapAsync(
manager.GetManifests(localPackagePair.InstalledPackage)
);
var steps = new List<IPackageStep>();
foreach (var missingExtensionUrl in missingExtensions)
{
if (!manifestExtensionsMap.TryGetValue(missingExtensionUrl, out var extension))
{
Logger.Warn(
"Extension {MissingExtensionUrl} not found in manifests",
missingExtensionUrl
);
continue;
}
steps.Add(new InstallExtensionStep(manager, localPackagePair.InstalledPackage, extension));
}
var runner = new PackageModificationRunner
{
ShowDialogOnStart = true,
ModificationCompleteTitle = "Extensions Installed",
ModificationCompleteMessage = "Finished installing required extensions"
};
EventManager.Instance.OnPackageInstallProgressAdded(runner);
runner
.ExecuteSteps(steps)
.ContinueWith(async _ =>
{
if (runner.Failed)
return;
// Restart Package
// TODO: This should be handled by some DI package manager service
var launchPage = App.Services.GetRequiredService<LaunchPageViewModel>();
try
{
await Dispatcher.UIThread.InvokeAsync(async () =>
{
await launchPage.Stop();
await launchPage.LaunchAsync();
});
}
catch (Exception e)
{
Logger.Error(e, "Error while restarting package");
notificationService.ShowPersistent(
new AppException(
"Could not restart package",
"Please manually restart the package for extension changes to take effect"
)
);
}
})
.SafeFireAndForget();
}
return false;
}
/// <summary> /// <summary>
/// Handles the preview image received event from the websocket. /// Handles the preview image received event from the websocket.
/// Updates the preview image in the image gallery. /// Updates the preview image in the image gallery.
@ -683,6 +813,7 @@ public abstract partial class InferenceGenerationViewModelBase
public required ComfyClient Client { get; init; } public required ComfyClient Client { get; init; }
public required NodeDictionary Nodes { get; init; } public required NodeDictionary Nodes { get; init; }
public required IReadOnlyList<string> OutputNodeNames { get; init; } public required IReadOnlyList<string> OutputNodeNames { get; init; }
public int BatchIndex { get; init; }
public GenerationParameters? Parameters { get; init; } public GenerationParameters? Parameters { get; init; }
public InferenceProjectDocument? Project { get; init; } public InferenceProjectDocument? Project { get; init; }
public bool ClearOutputImages { get; init; } = true; public bool ClearOutputImages { get; init; } = true;

14
StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs

@ -19,11 +19,13 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)] [JsonDerivedType(typeof(FreeUCardViewModel), FreeUCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)] [JsonDerivedType(typeof(UpscalerCardViewModel), UpscalerCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(ControlNetCardViewModel), ControlNetCardViewModel.ModuleKey)] [JsonDerivedType(typeof(ControlNetCardViewModel), ControlNetCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(PromptExpansionCardViewModel), PromptExpansionCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(FreeUModule))] [JsonDerivedType(typeof(FreeUModule))]
[JsonDerivedType(typeof(HiresFixModule))] [JsonDerivedType(typeof(HiresFixModule))]
[JsonDerivedType(typeof(UpscalerModule))] [JsonDerivedType(typeof(UpscalerModule))]
[JsonDerivedType(typeof(ControlNetModule))] [JsonDerivedType(typeof(ControlNetModule))]
[JsonDerivedType(typeof(SaveImageModule))] [JsonDerivedType(typeof(SaveImageModule))]
[JsonDerivedType(typeof(PromptExpansionModule))]
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -32,7 +34,8 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) }; private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) };
private static readonly JsonSerializerOptions SerializerOptions = new() { IgnoreReadOnlyProperties = true }; private static readonly JsonSerializerOptions SerializerOptions =
new() { IgnoreReadOnlyProperties = true };
private static bool ShouldIgnoreProperty(PropertyInfo property) private static bool ShouldIgnoreProperty(PropertyInfo property)
{ {
@ -243,7 +246,11 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
} }
else else
{ {
Logger.ConditionalTrace("Serializing {Property} ({Type})", property.Name, property.PropertyType); Logger.ConditionalTrace(
"Serializing {Property} ({Type})",
property.Name,
property.PropertyType
);
var value = property.GetValue(this); var value = property.GetValue(this);
if (value is not null) if (value is not null)
{ {
@ -266,7 +273,8 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
protected static JsonObject SerializeModel<T>(T model) protected static JsonObject SerializeModel<T>(T model)
{ {
var node = JsonSerializer.SerializeToNode(model); var node = JsonSerializer.SerializeToNode(model);
return node?.AsObject() ?? throw new NullReferenceException("Failed to serialize state to JSON object."); return node?.AsObject()
?? throw new NullReferenceException("Failed to serialize state to JSON object.");
} }
/// <summary> /// <summary>

1
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs

@ -200,6 +200,7 @@ public partial class InferenceImageToVideoViewModel
Parameters = SaveStateToParameters(new GenerationParameters()), Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this), Project = InferenceProjectDocument.FromLoadable(this),
FilesToTransfer = buildPromptArgs.FilesToTransfer, FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch // Only clear output images on the first batch
ClearOutputImages = i == 0 ClearOutputImages = i == 0
}; };

1
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -214,6 +214,7 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, I
Parameters = SaveStateToParameters(new GenerationParameters()), Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this), Project = InferenceProjectDocument.FromLoadable(this),
FilesToTransfer = buildPromptArgs.FilesToTransfer, FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch // Only clear output images on the first batch
ClearOutputImages = i == 0 ClearOutputImages = i == 0
}; };

40
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/PromptExpansionModule.cs

@ -0,0 +1,40 @@
using System;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
[ManagedService]
[Transient]
public class PromptExpansionModule : ModuleBase
{
public PromptExpansionModule(ServiceManager<ViewModelBase> vmFactory)
: base(vmFactory)
{
Title = "Prompt Expansion";
AddCards(vmFactory.Get<PromptExpansionCardViewModel>());
}
protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{
var promptExpansionCard = GetCard<PromptExpansionCardViewModel>();
var model =
promptExpansionCard.SelectedModel
?? throw new InvalidOperationException($"{Title}: Model not selected");
e.Builder.Connections.PositivePrompt = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.PromptExpansion
{
Name = e.Nodes.GetUniqueName("PromptExpansion_Positive"),
ModelName = model.RelativePath,
Text = e.Builder.Connections.PositivePrompt,
Seed = e.Builder.Connections.Seed,
LogPrompt = promptExpansionCard.IsLogOutputEnabled
}
).Output;
}
}

50
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,5 +1,4 @@
using System; using System.Linq;
using System.Linq;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
@ -13,12 +12,13 @@ using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
@ -43,6 +43,8 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
public TextDocument PromptDocument { get; } = new(); public TextDocument PromptDocument { get; } = new();
public TextDocument NegativePromptDocument { get; } = new(); public TextDocument NegativePromptDocument { get; } = new();
public StackEditableCardViewModel ModulesCardViewModel { get; }
[ObservableProperty] [ObservableProperty]
private bool isAutoCompletionEnabled; private bool isAutoCompletionEnabled;
@ -52,6 +54,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
ITokenizerProvider tokenizerProvider, ITokenizerProvider tokenizerProvider,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IModelIndexService modelIndexService, IModelIndexService modelIndexService,
ServiceManager<ViewModelBase> vmFactory,
SharedState sharedState SharedState sharedState
) )
{ {
@ -60,6 +63,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
TokenizerProvider = tokenizerProvider; TokenizerProvider = tokenizerProvider;
SharedState = sharedState; SharedState = sharedState;
ModulesCardViewModel = vmFactory.Get<StackEditableCardViewModel>(vm =>
{
vm.Title = "Styles";
vm.AvailableModules = [typeof(PromptExpansionModule)];
});
settingsManager.RelayPropertyFor( settingsManager.RelayPropertyFor(
this, this,
vm => vm.IsAutoCompletionEnabled, vm => vm.IsAutoCompletionEnabled,
@ -84,8 +93,14 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
// Load prompts // Load prompts
var positivePrompt = GetPrompt(); var positivePrompt = GetPrompt();
positivePrompt.Process(); positivePrompt.Process();
e.Builder.Connections.PositivePrompt = positivePrompt.ProcessedText;
var negativePrompt = GetNegativePrompt(); var negativePrompt = GetNegativePrompt();
negativePrompt.Process(); negativePrompt.Process();
e.Builder.Connections.NegativePrompt = negativePrompt.ProcessedText;
// Apply modules / styles that may modify the prompt
ModulesCardViewModel.ApplyStep(e);
foreach (var modelConnections in e.Builder.Connections.Models.Values) foreach (var modelConnections in e.Builder.Connections.Models.Values)
{ {
@ -98,7 +113,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList(); var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList();
// Add group to load loras onto model and clip in series // Add group to load loras onto model and clip in series
var lorasGroup = e.Builder.Group_LoraLoadMany($"Loras_{modelConnections.Name}", model, clip, loras); var lorasGroup = e.Builder.Group_LoraLoadMany(
$"Loras_{modelConnections.Name}",
model,
clip,
loras
);
// Set last outputs as model and clip // Set last outputs as model and clip
modelConnections.Model = lorasGroup.Output1; modelConnections.Model = lorasGroup.Output1;
@ -111,7 +131,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
{ {
Name = $"PositiveCLIP_{modelConnections.Name}", Name = $"PositiveCLIP_{modelConnections.Name}",
Clip = e.Builder.Connections.Base.Clip!, Clip = e.Builder.Connections.Base.Clip!,
Text = positivePrompt.ProcessedText Text = e.Builder.Connections.PositivePrompt
} }
); );
var negativeClip = e.Nodes.AddTypedNode( var negativeClip = e.Nodes.AddTypedNode(
@ -119,7 +139,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
{ {
Name = $"NegativeCLIP_{modelConnections.Name}", Name = $"NegativeCLIP_{modelConnections.Name}",
Clip = e.Builder.Connections.Base.Clip!, Clip = e.Builder.Connections.Base.Clip!,
Text = negativePrompt.ProcessedText Text = e.Builder.Connections.NegativePrompt
} }
); );
@ -319,7 +339,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
public override JsonObject SaveStateToJsonObject() public override JsonObject SaveStateToJsonObject()
{ {
return SerializeModel( return SerializeModel(
new PromptCardModel { Prompt = PromptDocument.Text, NegativePrompt = NegativePromptDocument.Text } new PromptCardModel
{
Prompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text,
ModulesCardState = ModulesCardViewModel.SaveStateToJsonObject()
}
); );
} }
@ -330,6 +355,11 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
PromptDocument.Text = model.Prompt ?? ""; PromptDocument.Text = model.Prompt ?? "";
NegativePromptDocument.Text = model.NegativePrompt ?? ""; NegativePromptDocument.Text = model.NegativePrompt ?? "";
if (model.ModulesCardState is not null)
{
ModulesCardViewModel.LoadStateFromJsonObject(model.ModulesCardState);
}
} }
/// <inheritdoc /> /// <inheritdoc />
@ -342,6 +372,10 @@ public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoa
/// <inheritdoc /> /// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters) public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{ {
return parameters with { PositivePrompt = PromptDocument.Text, NegativePrompt = NegativePromptDocument.Text }; return parameters with
{
PositivePrompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text
};
} }
} }

47
StabilityMatrix.Avalonia/ViewModels/Inference/PromptExpansionCardViewModel.cs

@ -0,0 +1,47 @@
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptExpansionCard))]
[ManagedService]
[Transient]
public partial class PromptExpansionCardViewModel(
IInferenceClientManager clientManager,
ServiceManager<ViewModelBase> vmFactory
) : LoadableViewModelBase
{
public const string ModuleKey = "PromptExpansion";
public IInferenceClientManager ClientManager { get; } = clientManager;
[ObservableProperty]
private HybridModelFile? selectedModel;
[ObservableProperty]
private bool isLogOutputEnabled = true;
[RelayCommand]
private async Task RemoteDownload(HybridModelFile? model)
{
if (model?.DownloadableResource is not { } resource)
return;
var confirmDialog = vmFactory.Get<DownloadResourceViewModel>();
confirmDialog.Resource = resource;
confirmDialog.FileName = resource.FileName;
if (await confirmDialog.GetDialog().ShowAsync() == ContentDialogResult.Primary)
{
confirmDialog.StartDownload();
}
}
}

12
StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs

@ -5,6 +5,7 @@ using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
@ -16,7 +17,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(StackEditableCard))] [View(typeof(StackEditableCard))]
[ManagedService] [ManagedService]
[Transient] [Transient]
public partial class StackEditableCardViewModel : StackViewModelBase public partial class StackEditableCardViewModel : StackViewModelBase, IComfyStep
{ {
private readonly ServiceManager<ViewModelBase> vmFactory; private readonly ServiceManager<ViewModelBase> vmFactory;
@ -68,6 +69,15 @@ public partial class StackEditableCardViewModel : StackViewModelBase
} }
} }
/// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e)
{
foreach (var module in Cards.OfType<IComfyStep>())
{
module.ApplyStep(e);
}
}
/// <inheritdoc /> /// <inheritdoc />
protected override void OnCardAdded(LoadableViewModelBase item) protected override void OnCardAdded(LoadableViewModelBase item)
{ {

2
StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs

@ -206,7 +206,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
} }
[RelayCommand] [RelayCommand]
private async Task LaunchAsync(string? command = null) public async Task LaunchAsync(string? command = null)
{ {
await notificationService.TryAsync(LaunchImpl(command)); await notificationService.TryAsync(LaunchImpl(command));
} }

8
StabilityMatrix.Core/Api/ApiFactory.cs

@ -18,4 +18,12 @@ public class ApiFactory : IApiFactory
httpClient.BaseAddress = baseAddress; httpClient.BaseAddress = baseAddress;
return RestService.For<T>(httpClient, RefitSettings); return RestService.For<T>(httpClient, RefitSettings);
} }
public T CreateRefitClient<T>(Uri baseAddress, RefitSettings refitSettings)
{
var httpClient = httpClientFactory.CreateClient(nameof(T));
httpClient.BaseAddress = baseAddress;
return RestService.For<T>(httpClient, refitSettings);
}
} }

6
StabilityMatrix.Core/Api/IApiFactory.cs

@ -1,6 +1,10 @@
namespace StabilityMatrix.Core.Api; using Refit;
namespace StabilityMatrix.Core.Api;
public interface IApiFactory public interface IApiFactory
{ {
public T CreateRefitClient<T>(Uri baseAddress); public T CreateRefitClient<T>(Uri baseAddress);
public T CreateRefitClient<T>(Uri baseAddress, RefitSettings refitSettings);
} }

14
StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs

@ -0,0 +1,14 @@
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Core.Attributes;
/// <summary>
/// Options for <see cref="ComfyTypedNodeBase{TOutput}"/>
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
public class TypedNodeOptionsAttribute : Attribute
{
public string? Name { get; init; }
public string[]? RequiredExtensions { get; init; }
}

49
StabilityMatrix.Core/Converters/Json/NodeConnectionBaseJsonConverter.cs

@ -0,0 +1,49 @@
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Core.Converters.Json;
public class NodeConnectionBaseJsonConverter : JsonConverter<NodeConnectionBase>
{
/// <inheritdoc />
public override NodeConnectionBase Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options
)
{
// Read as Data array
reader.Read();
var data = new object[2];
reader.Read();
data[0] = reader.GetString() ?? throw new JsonException("Expected string for node name");
reader.Read();
data[1] = reader.GetInt32();
reader.Read();
if (Activator.CreateInstance(typeToConvert) is not NodeConnectionBase instance)
{
throw new JsonException($"Failed to create instance of {typeToConvert}");
}
var propertyInfo =
typeToConvert.GetProperty("Data", BindingFlags.Public | BindingFlags.Instance)
?? throw new JsonException($"Failed to get Data property of {typeToConvert}");
propertyInfo.SetValue(instance, data);
return instance;
}
/// <inheritdoc />
public override void Write(Utf8JsonWriter writer, NodeConnectionBase value, JsonSerializerOptions options)
{
// Write as Data array
writer.WriteStartArray();
writer.WriteStringValue(value.Data?[0] as string);
writer.WriteNumberValue((int)value.Data?[1]!);
writer.WriteEndArray();
}
}

63
StabilityMatrix.Core/Converters/Json/OneOfJsonConverter.cs

@ -0,0 +1,63 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using OneOf;
namespace StabilityMatrix.Core.Converters.Json;
public class OneOfJsonConverter<T1, T2> : JsonConverter<OneOf<T1, T2>>
{
/// <inheritdoc />
public override OneOf<T1, T2> Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options
)
{
// Not sure how else to do this without polymorphic type markers but that would not serialize into T1/T2
// So just try to deserialize T1, if it fails, try T2
Exception? t1Exception = null;
Exception? t2Exception = null;
try
{
if (JsonSerializer.Deserialize<T1>(ref reader, options) is { } t1)
{
return t1;
}
}
catch (JsonException e)
{
t1Exception = e;
}
try
{
if (JsonSerializer.Deserialize<T2>(ref reader, options) is { } t2)
{
return t2;
}
}
catch (JsonException e)
{
t2Exception = e;
}
throw new JsonException(
$"Failed to deserialize OneOf<{typeof(T1)}, {typeof(T2)}> as either {typeof(T1)} or {typeof(T2)}",
new AggregateException([t1Exception, t2Exception])
);
}
/// <inheritdoc />
public override void Write(Utf8JsonWriter writer, OneOf<T1, T2> value, JsonSerializerOptions options)
{
if (value.IsT0)
{
JsonSerializer.Serialize(writer, value.AsT0, options);
}
else
{
JsonSerializer.Serialize(writer, value.AsT1, options);
}
}
}

22
StabilityMatrix.Core/Helper/RemoteModels.cs

@ -132,8 +132,7 @@ public static class RemoteModels
} }
}; };
private static Uri ControlNetRoot { get; } = private static Uri ControlNetRoot { get; } = new("https://huggingface.co/lllyasviel/ControlNet/");
new("https://huggingface.co/lllyasviel/ControlNet/");
private static RemoteResource ControlNetCommon(string path, string sha256) private static RemoteResource ControlNetCommon(string path, string sha256)
{ {
@ -170,4 +169,23 @@ public static class RemoteModels
public static IReadOnlyList<HybridModelFile> ControlNetModels { get; } = public static IReadOnlyList<HybridModelFile> ControlNetModels { get; } =
ControlNets.Select(HybridModelFile.FromDownloadable).ToImmutableArray(); ControlNets.Select(HybridModelFile.FromDownloadable).ToImmutableArray();
private static IEnumerable<RemoteResource> PromptExpansions =>
[
new RemoteResource
{
Url = new Uri("https://cdn.lykos.ai/models/GPT-Prompt-Expansion-Fooocus-v2.zip"),
HashSha256 = "82e69311787c0bb6736389710d80c0a2b653ed9bbe6ea6e70c6b90820fe42d88",
InfoUrl = new Uri("https://huggingface.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2"),
Author = "lllyasviel, LykosAI",
LicenseType = "GPLv3",
LicenseUrl = new Uri("https://github.com/lllyasviel/Fooocus/blob/main/LICENSE"),
ContextType = SharedFolderType.PromptExpansion,
AutoExtractArchive = true,
ExtractRelativePath = "GPT-Prompt-Expansion-Fooocus-v2"
}
];
public static IEnumerable<HybridModelFile> PromptExpansionModels =>
PromptExpansions.Select(HybridModelFile.FromDownloadable);
} }

28
StabilityMatrix.Core/Inference/ComfyClient.cs

@ -6,10 +6,13 @@ using NLog;
using Polly.Contrib.WaitAndRetry; using Polly.Contrib.WaitAndRetry;
using Refit; using Refit;
using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Converters.Json;
using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.FileInterfaces;
using Websocket.Client; using Websocket.Client;
@ -26,8 +29,16 @@ public class ComfyClient : InferenceClientBase
private readonly IComfyApi comfyApi; private readonly IComfyApi comfyApi;
private bool isDisposed; private bool isDisposed;
private JsonSerializerOptions jsonSerializerOptions = private readonly JsonSerializerOptions jsonSerializerOptions =
new() { PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower, }; new()
{
PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower,
Converters =
{
new NodeConnectionBaseJsonConverter(),
new OneOfJsonConverter<string, StringNodeConnection>()
}
};
// ReSharper disable once MemberCanBePrivate.Global // ReSharper disable once MemberCanBePrivate.Global
public string ClientId { get; } = Guid.NewGuid().ToString(); public string ClientId { get; } = Guid.NewGuid().ToString();
@ -39,6 +50,11 @@ public class ComfyClient : InferenceClientBase
/// </summary> /// </summary>
public DirectoryPath? LocalServerPath { get; set; } public DirectoryPath? LocalServerPath { get; set; }
/// <summary>
/// If available, the local server package pair
/// </summary>
public PackagePair? LocalServerPackage { get; set; }
/// <summary> /// <summary>
/// Path to the "output" folder from LocalServerPath /// Path to the "output" folder from LocalServerPath
/// </summary> /// </summary>
@ -81,7 +97,13 @@ public class ComfyClient : InferenceClientBase
public ComfyClient(IApiFactory apiFactory, Uri baseAddress) public ComfyClient(IApiFactory apiFactory, Uri baseAddress)
{ {
comfyApi = apiFactory.CreateRefitClient<IComfyApi>(baseAddress); comfyApi = apiFactory.CreateRefitClient<IComfyApi>(
baseAddress,
new RefitSettings
{
ContentSerializer = new SystemTextJsonContentSerializer(jsonSerializerOptions),
}
);
BaseAddress = baseAddress; BaseAddress = baseAddress;
// Setup websocket client // Setup websocket client

18
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnectionBase.cs

@ -1,12 +1,14 @@
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using System.Text.Json.Serialization;
using StabilityMatrix.Core.Converters.Json;
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
[JsonConverter(typeof(NodeConnectionBaseJsonConverter))]
public abstract class NodeConnectionBase public abstract class NodeConnectionBase
{ {
public object[]? Data { get; set; } /// <summary>
/// Array data for the connection.
// Implicit conversion to object[] /// [(string) Node Name, (int) Connection Index]
public static implicit operator object[](NodeConnectionBase nodeConnection) /// </summary>
{ public object[]? Data { get; init; }
return nodeConnection.Data ?? Array.Empty<object>();
}
} }

26
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnections.cs

@ -1,25 +1,27 @@
namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; namespace StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
public class LatentNodeConnection : NodeConnectionBase { } public class LatentNodeConnection : NodeConnectionBase;
public class VAENodeConnection : NodeConnectionBase { } public class VAENodeConnection : NodeConnectionBase;
public class ImageNodeConnection : NodeConnectionBase { } public class ImageNodeConnection : NodeConnectionBase;
public class ImageMaskConnection : NodeConnectionBase { } public class ImageMaskConnection : NodeConnectionBase;
public class UpscaleModelNodeConnection : NodeConnectionBase { } public class UpscaleModelNodeConnection : NodeConnectionBase;
public class ModelNodeConnection : NodeConnectionBase { } public class ModelNodeConnection : NodeConnectionBase;
public class ConditioningNodeConnection : NodeConnectionBase { } public class ConditioningNodeConnection : NodeConnectionBase;
public class ClipNodeConnection : NodeConnectionBase { } public class ClipNodeConnection : NodeConnectionBase;
public class ControlNetNodeConnection : NodeConnectionBase { } public class ControlNetNodeConnection : NodeConnectionBase;
public class ClipVisionNodeConnection : NodeConnectionBase { } public class ClipVisionNodeConnection : NodeConnectionBase;
public class SamplerNodeConnection : NodeConnectionBase { } public class SamplerNodeConnection : NodeConnectionBase;
public class SigmasNodeConnection : NodeConnectionBase { } public class SigmasNodeConnection : NodeConnectionBase;
public class StringNodeConnection : NodeConnectionBase;

37
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -1,8 +1,8 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Drawing; using System.Drawing;
using System.Runtime.Serialization; using OneOf;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
@ -14,6 +14,7 @@ namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
/// Builder functions for comfy nodes /// Builder functions for comfy nodes
/// </summary> /// </summary>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[Localizable(false)]
public class ComfyNodeBuilder public class ComfyNodeBuilder
{ {
public NodeDictionary Nodes { get; } = new(); public NodeDictionary Nodes { get; } = new();
@ -258,20 +259,7 @@ public class ComfyNodeBuilder
public record CLIPTextEncode : ComfyTypedNodeBase<ConditioningNodeConnection> public record CLIPTextEncode : ComfyTypedNodeBase<ConditioningNodeConnection>
{ {
public required ClipNodeConnection Clip { get; init; } public required ClipNodeConnection Clip { get; init; }
public required string Text { get; init; } public required OneOf<string, StringNodeConnection> Text { get; init; }
}
public static NamedComfyNode<ConditioningNodeConnection> ClipTextEncode(
string name,
ClipNodeConnection clip,
string text
)
{
return new NamedComfyNode<ConditioningNodeConnection>(name)
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?> { ["clip"] = clip.Data, ["text"] = text }
};
} }
public record LoadImage : ComfyTypedNodeBase<ImageNodeConnection, ImageMaskConnection> public record LoadImage : ComfyTypedNodeBase<ImageNodeConnection, ImageMaskConnection>
@ -342,6 +330,18 @@ public class ComfyNodeBuilder
public required string Method { get; init; } public required string Method { get; init; }
} }
[TypedNodeOptions(
Name = "Inference_Core_PromptExpansion",
RequiredExtensions = ["https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes"]
)]
public record PromptExpansion : ComfyTypedNodeBase<StringNodeConnection>
{
public required string ModelName { get; init; }
public required OneOf<string, StringNodeConnection> Text { get; init; }
public required ulong Seed { get; init; }
public bool LogPrompt { get; init; }
}
public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae) public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
{ {
var name = GetUniqueName("VAEDecode"); var name = GetUniqueName("VAEDecode");
@ -801,6 +801,9 @@ public class ComfyNodeBuilder
public int BatchSize { get; set; } = 1; public int BatchSize { get; set; } = 1;
public int? BatchIndex { get; set; } public int? BatchIndex { get; set; }
public OneOf<string, StringNodeConnection> PositivePrompt { get; set; }
public OneOf<string, StringNodeConnection> NegativePrompt { get; set; }
public ClipNodeConnection? BaseClip { get; set; } public ClipNodeConnection? BaseClip { get; set; }
public ClipVisionNodeConnection? BaseClipVision { get; set; } public ClipVisionNodeConnection? BaseClipVision { get; set; }

25
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyTypedNodeBase.cs

@ -1,4 +1,5 @@
using System.Reflection; using System.ComponentModel;
using System.Reflection;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
@ -8,8 +9,28 @@ namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
public abstract record ComfyTypedNodeBase public abstract record ComfyTypedNodeBase
{ {
protected virtual string ClassType => GetType().Name; [Localizable(false)]
protected virtual string ClassType
{
get
{
var type = GetType();
// Use options name if available
if (type.GetCustomAttribute<TypedNodeOptionsAttribute>() is { } options)
{
if (!string.IsNullOrEmpty(options.Name))
{
return options.Name;
}
}
// Otherwise use class name
return type.Name;
}
}
[Localizable(false)]
[JsonIgnore] [JsonIgnore]
public required string Name { get; init; } public required string Name { get; init; }

33
StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs

@ -1,4 +1,9 @@
using StabilityMatrix.Core.Helper; using System.ComponentModel;
using System.Reflection;
using System.Text.Json.Serialization;
using OneOf;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes; namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes;
@ -10,10 +15,16 @@ public class NodeDictionary : Dictionary<string, ComfyNode>
/// </summary> /// </summary>
private readonly Dictionary<string, int> _baseNameIndex = new(); private readonly Dictionary<string, int> _baseNameIndex = new();
/// <summary>
/// When inserting TypedNodes, this holds a mapping of ClassType to required extensions
/// </summary>
[JsonIgnore]
public Dictionary<string, string[]> ClassTypeRequiredExtensions { get; } = new();
/// <summary> /// <summary>
/// Finds a unique node name given a base name, by appending _2, _3, etc. /// Finds a unique node name given a base name, by appending _2, _3, etc.
/// </summary> /// </summary>
public string GetUniqueName(string nameBase) public string GetUniqueName([Localizable(false)] string nameBase)
{ {
if (_baseNameIndex.TryGetValue(nameBase, out var index)) if (_baseNameIndex.TryGetValue(nameBase, out var index))
{ {
@ -43,7 +54,19 @@ public class NodeDictionary : Dictionary<string, ComfyNode>
public TTypedNode AddTypedNode<TTypedNode>(TTypedNode node) public TTypedNode AddTypedNode<TTypedNode>(TTypedNode node)
where TTypedNode : ComfyTypedNodeBase where TTypedNode : ComfyTypedNodeBase
{ {
Add(node.Name, node); var namedNode = (NamedComfyNode)node;
Add(node.Name, namedNode);
// Check statically annotated stuff for TypedNodeOptionsAttribute
if (node.GetType().GetCustomAttribute<TypedNodeOptionsAttribute>() is { } options)
{
if (options.RequiredExtensions != null)
{
ClassTypeRequiredExtensions[namedNode.ClassType] = options.RequiredExtensions;
}
}
return node; return node;
} }
@ -62,6 +85,10 @@ public class NodeDictionary : Dictionary<string, ComfyNode>
{ {
node.Inputs[key] = connection.Data; node.Inputs[key] = connection.Data;
} }
else if (input is IOneOf { Value: NodeConnectionBase oneOfConnection })
{
node.Inputs[key] = oneOfConnection.Data;
}
} }
} }
} }

1
StabilityMatrix.Core/Models/HybridModelFile.cs

@ -71,6 +71,7 @@ public record HybridModelFile
if ( if (
!fileName.Equals("diffusion_pytorch_model", StringComparison.OrdinalIgnoreCase) !fileName.Equals("diffusion_pytorch_model", StringComparison.OrdinalIgnoreCase)
&& !fileName.Equals("pytorch_model", StringComparison.OrdinalIgnoreCase)
&& !fileName.Equals("ip_adapter", StringComparison.OrdinalIgnoreCase) && !fileName.Equals("ip_adapter", StringComparison.OrdinalIgnoreCase)
) )
{ {

10
StabilityMatrix.Core/Models/Packages/ComfyUI.cs

@ -68,6 +68,7 @@ public class ComfyUI(
[SharedFolderType.InvokeIpAdapters15] = new[] { "models/ipadapter/sd15" }, [SharedFolderType.InvokeIpAdapters15] = new[] { "models/ipadapter/sd15" },
[SharedFolderType.InvokeIpAdaptersXl] = new[] { "models/ipadapter/sdxl" }, [SharedFolderType.InvokeIpAdaptersXl] = new[] { "models/ipadapter/sdxl" },
[SharedFolderType.T2IAdapter] = new[] { "models/controlnet/T2IAdapter" }, [SharedFolderType.T2IAdapter] = new[] { "models/controlnet/T2IAdapter" },
[SharedFolderType.PromptExpansion] = new[] { "models/prompt_expansion" }
}; };
public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders => public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders =>
@ -373,6 +374,7 @@ public class ComfyUI(
Path.Combine(modelsDir, "InvokeIpAdapters15"), Path.Combine(modelsDir, "InvokeIpAdapters15"),
Path.Combine(modelsDir, "InvokeIpAdaptersXl") Path.Combine(modelsDir, "InvokeIpAdaptersXl")
); );
nodeValue.Children["prompt_expansion"] = Path.Combine(modelsDir, "PromptExpansion");
} }
else else
{ {
@ -410,7 +412,8 @@ public class ComfyUI(
Path.Combine(modelsDir, "InvokeIpAdapters15"), Path.Combine(modelsDir, "InvokeIpAdapters15"),
Path.Combine(modelsDir, "InvokeIpAdaptersXl") Path.Combine(modelsDir, "InvokeIpAdaptersXl")
) )
} },
{ "prompt_expansion", Path.Combine(modelsDir, "PromptExpansion") }
} }
); );
} }
@ -476,9 +479,8 @@ public class ComfyUI(
public override IEnumerable<ExtensionManifest> DefaultManifests => public override IEnumerable<ExtensionManifest> DefaultManifests =>
[ [
new ExtensionManifest( "https://cdn.jsdelivr.net/gh/ltdrdata/ComfyUI-Manager/custom-node-list.json",
new Uri("https://cdn.jsdelivr.net/gh/ltdrdata/ComfyUI-Manager/custom-node-list.json") "https://cdn.jsdelivr.net/gh/LykosAI/ComfyUI-Extensions-Index/custom-node-list.json"
)
]; ];
public override async Task<IEnumerable<PackageExtension>> GetManifestExtensionsAsync( public override async Task<IEnumerable<PackageExtension>> GetManifestExtensionsAsync(

5
StabilityMatrix.Core/Models/Packages/Extensions/ExtensionManifest.cs

@ -1,3 +1,6 @@
namespace StabilityMatrix.Core.Models.Packages.Extensions; namespace StabilityMatrix.Core.Models.Packages.Extensions;
public record ExtensionManifest(Uri Uri); public record ExtensionManifest(Uri Uri)
{
public static implicit operator ExtensionManifest(string uri) => new(new Uri(uri, UriKind.Absolute));
}

78
StabilityMatrix.Core/Models/Packages/Extensions/GitPackageExtensionManager.cs

@ -1,4 +1,6 @@
using System.Text.RegularExpressions;
using KGySoft.CoreLibraries; using KGySoft.CoreLibraries;
using Microsoft.Extensions.Caching.Memory;
using NLog; using NLog;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
@ -8,11 +10,14 @@ using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Models.Packages.Extensions; namespace StabilityMatrix.Core.Models.Packages.Extensions;
public abstract class GitPackageExtensionManager(IPrerequisiteHelper prerequisiteHelper) public abstract partial class GitPackageExtensionManager(IPrerequisiteHelper prerequisiteHelper)
: IPackageExtensionManager : IPackageExtensionManager
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
// Cache checks of installed extensions
private readonly MemoryCache installedExtensionsCache = new(new MemoryCacheOptions());
public abstract string RelativeInstallDirectory { get; } public abstract string RelativeInstallDirectory { get; }
public virtual IEnumerable<ExtensionManifest> DefaultManifests { get; } = public virtual IEnumerable<ExtensionManifest> DefaultManifests { get; } =
@ -128,6 +133,74 @@ public abstract class GitPackageExtensionManager(IPrerequisiteHelper prerequisit
return extensions; return extensions;
} }
/// <summary>
/// Like <see cref="GetInstalledExtensionsAsync"/>, but does not check git version and repository url.
/// </summary>
public virtual async Task<IEnumerable<InstalledPackageExtension>> GetInstalledExtensionsLiteAsync(
InstalledPackage installedPackage,
CancellationToken cancellationToken = default
)
{
if (installedPackage.FullPath is not { } packagePath)
{
return Enumerable.Empty<InstalledPackageExtension>();
}
var extensions = new List<InstalledPackageExtension>();
// Search for installed extensions in the package's index directories.
foreach (
var indexDirectory in IndexRelativeDirectories.Select(
path => new DirectoryPath(packagePath, path)
)
)
{
cancellationToken.ThrowIfCancellationRequested();
// Skip directory if not exists
if (!indexDirectory.Exists)
{
continue;
}
// Check subdirectories of the index directory
foreach (var subDirectory in indexDirectory.EnumerateDirectories())
{
cancellationToken.ThrowIfCancellationRequested();
// Skip if not valid git repository
if (!subDirectory.JoinDir(".git").Exists)
{
continue;
}
// Get remote url with manual parsing
string? remoteUrl = null;
var gitConfigPath = subDirectory.JoinDir(".git").JoinFile("config");
if (
gitConfigPath.Exists
&& await gitConfigPath.ReadAllTextAsync(cancellationToken).ConfigureAwait(false)
is { } gitConfigText
)
{
var pattern = GitConfigRemoteOriginUrlRegex();
var match = pattern.Match(gitConfigText);
if (match.Success)
{
remoteUrl = match.Groups[1].Value;
}
}
extensions.Add(
new InstalledPackageExtension { Paths = [subDirectory], GitRepositoryUrl = remoteUrl }
);
}
}
return extensions;
}
/// <inheritdoc /> /// <inheritdoc />
public virtual async Task InstallExtensionAsync( public virtual async Task InstallExtensionAsync(
PackageExtension extension, PackageExtension extension,
@ -242,4 +315,7 @@ public abstract class GitPackageExtensionManager(IPrerequisiteHelper prerequisit
progress?.Report(new ProgressReport(1f, message: "Uninstalled extension")); progress?.Report(new ProgressReport(1f, message: "Uninstalled extension"));
} }
[GeneratedRegex("""\[remote "origin"\][\s\S]*?url\s*=\s*(.+)""")]
private static partial Regex GitConfigRemoteOriginUrlRegex();
} }

29
StabilityMatrix.Core/Models/Packages/Extensions/IPackageExtensionManager.cs

@ -52,6 +52,35 @@ public interface IPackageExtensionManager
return extensions; return extensions;
} }
/// <summary>
/// Get unique extensions from all provided manifests. As a mapping of their reference.
/// </summary>
async Task<IDictionary<string, PackageExtension>> GetManifestExtensionsMapAsync(
IEnumerable<ExtensionManifest> manifests,
CancellationToken cancellationToken = default
)
{
var result = new Dictionary<string, PackageExtension>();
foreach (
var extension in await GetManifestExtensionsAsync(manifests, cancellationToken)
.ConfigureAwait(false)
)
{
cancellationToken.ThrowIfCancellationRequested();
var key = extension.Reference.ToString();
if (!result.TryAdd(key, extension))
{
// Replace
result[key] = extension;
}
}
return result;
}
/// <summary> /// <summary>
/// Get all installed extensions for the provided package. /// Get all installed extensions for the provided package.
/// </summary> /// </summary>

2
StabilityMatrix.Core/Models/SharedFolderType.cs

@ -39,4 +39,6 @@ public enum SharedFolderType
InvokeIpAdapters15 = 1 << 24, InvokeIpAdapters15 = 1 << 24,
InvokeIpAdaptersXl = 1 << 25, InvokeIpAdaptersXl = 1 << 25,
InvokeClipVision = 1 << 26, InvokeClipVision = 1 << 26,
PromptExpansion = 1 << 30
} }

2
StabilityMatrix.Core/Models/TrackedDownload.cs

@ -124,6 +124,8 @@ public class TrackedDownload
{ {
var progress = new Progress<ProgressReport>(OnProgressUpdate); var progress = new Progress<ProgressReport>(OnProgressUpdate);
DownloadDirectory.Create();
await downloadService! await downloadService!
.ResumeDownloadToFileAsync( .ResumeDownloadToFileAsync(
SourceUrl.ToString(), SourceUrl.ToString(),

Loading…
Cancel
Save