Browse Source

Merge pull request #394 from ionite34/fix-binding

pull/324/head
Ionite 12 months ago committed by GitHub
parent
commit
106d6c0eaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      CHANGELOG.md
  2. 105
      StabilityMatrix.Avalonia/App.axaml.cs
  3. 4
      StabilityMatrix.Avalonia/Controls/ControlNetCard.axaml
  4. 6
      StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml
  5. 37
      StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs
  6. 2
      StabilityMatrix.Avalonia/MarkupExtensions/ShowDisabledTooltipExtension.cs
  7. 42
      StabilityMatrix.Avalonia/MarkupExtensions/TernaryExtension.cs
  8. 138
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  9. 14
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  10. 3
      StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs
  11. 51
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
  12. 23
      StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs
  13. 15
      StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
  14. 11
      StabilityMatrix.Core/Models/Api/Comfy/ComfySampler.cs
  15. 11
      StabilityMatrix.Core/Models/Api/Comfy/ComfyScheduler.cs
  16. 4
      StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnections.cs
  17. 227
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

8
CHANGELOG.md

@ -13,6 +13,14 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
- Use the plus button to add new steps (Hires Fix, Upscaler, and Save Image are currently available), and the edit button to enable removing or dragging steps to reorder them. This enables multi-pass Hires Fix, mixing different upscalers, and saving intermediate images at any point in the pipeline. - Use the plus button to add new steps (Hires Fix, Upscaler, and Save Image are currently available), and the edit button to enable removing or dragging steps to reorder them. This enables multi-pass Hires Fix, mixing different upscalers, and saving intermediate images at any point in the pipeline.
- Added Sampler addons - Added Sampler addons
- Addons usually affect guidance like ControlNet, T2I, FreeU, and other addons to come. They apply to the individual sampler, so you can mix and match different ControlNets for Base and Hires Fix, or use the current output from a previous sampler as ControlNet guidance image for HighRes passes. - Addons usually affect guidance like ControlNet, T2I, FreeU, and other addons to come. They apply to the individual sampler, so you can mix and match different ControlNets for Base and Hires Fix, or use the current output from a previous sampler as ControlNet guidance image for HighRes passes.
- Added SD Turbo Scheduler
- Added display names for new samplers ("Heun++ 2", "DDPM", "LCM")
#### Model Browser
- Added additional base model filter options ("SD 1.5 LCM", "SDXL 1.0 LCM", "SDXL Turbo", "Other")
### Changed
#### Inference
- Selected images (i.e. Image2Image, Upscale, ControlNet) will now save their source paths saved and restored on load. If the image is moved or deleted, the selection will show as missing and can be reselected
- Project files (.smproj) have been updated to v3, existing projects will be upgraded on load and will no longer be compatible with older versions of Stability Matrix
### Fixed ### Fixed
- Fixed Refiner model enabled state not saving to Inference project files - Fixed Refiner model enabled state not saving to Inference project files

105
StabilityMatrix.Avalonia/App.axaml.cs

@ -10,6 +10,7 @@ using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia; using Avalonia;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Controls.ApplicationLifetimes; using Avalonia.Controls.ApplicationLifetimes;
@ -69,6 +70,8 @@ namespace StabilityMatrix.Avalonia;
public sealed class App : Application public sealed class App : Application
{ {
private static bool isAsyncDisposeComplete;
[NotNull] [NotNull]
public static IServiceProvider? Services { get; private set; } public static IServiceProvider? Services { get; private set; }
@ -220,23 +223,7 @@ public sealed class App : Application
mainWindow.WindowStartupLocation = WindowStartupLocation.CenterScreen; mainWindow.WindowStartupLocation = WindowStartupLocation.CenterScreen;
} }
mainWindow.Closing += (_, _) => mainWindow.Closing += OnMainWindowClosing;
{
var validWindowPosition = mainWindow.Screens.All.Any(screen => screen.Bounds.Contains(mainWindow.Position));
settingsManager.Transaction(
s =>
{
s.WindowSettings = new WindowSettings(
mainWindow.Width,
mainWindow.Height,
validWindowPosition ? mainWindow.Position.X : 0,
validWindowPosition ? mainWindow.Position.Y : 0
);
},
ignoreMissingLibraryDir: true
);
};
mainWindow.Closed += (_, _) => Shutdown(); mainWindow.Closed += (_, _) => Shutdown();
mainWindow.SetDefaultFonts(); mainWindow.SetDefaultFonts();
@ -596,6 +583,87 @@ public sealed class App : Application
} }
} }
/// <summary>
/// Handle shutdown requests (happens before <see cref="OnExit"/>)
/// </summary>
private static void OnMainWindowClosing(object? sender, WindowClosingEventArgs e)
{
if (e.Cancel)
return;
var mainWindow = (MainWindow)sender!;
// Show confirmation if package running
var launchPageViewModel = Services.GetRequiredService<LaunchPageViewModel>();
launchPageViewModel.OnMainWindowClosing(e);
if (e.Cancel)
return;
// Check if we need to dispose IAsyncDisposables
if (
!isAsyncDisposeComplete
&& Services.GetServices<IAsyncDisposable>().ToList() is { Count: > 0 } asyncDisposables
)
{
// Cancel shutdown for now
e.Cancel = true;
isAsyncDisposeComplete = true;
Debug.WriteLine("OnShutdownRequested Canceled: Disposing IAsyncDisposables");
Task.Run(async () =>
{
foreach (var disposable in asyncDisposables)
{
Debug.WriteLine($"Disposing IAsyncDisposable ({disposable.GetType().Name})");
try
{
await disposable.DisposeAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
Debug.Fail(ex.ToString());
}
}
})
.ContinueWith(_ =>
{
// Shutdown again
Dispatcher.UIThread.Invoke(() => Shutdown());
})
.SafeFireAndForget();
return;
}
OnMainWindowClosingTerminal(mainWindow);
}
/// <summary>
/// Called at the end of <see cref="OnMainWindowClosing"/> before the main window is closed.
/// </summary>
private static void OnMainWindowClosingTerminal(Window sender)
{
var settingsManager = Services.GetRequiredService<ISettingsManager>();
// Save window position
var validWindowPosition = sender.Screens.All.Any(screen => screen.Bounds.Contains(sender.Position));
settingsManager.Transaction(
s =>
{
s.WindowSettings = new WindowSettings(
sender.Width,
sender.Height,
validWindowPosition ? sender.Position.X : 0,
validWindowPosition ? sender.Position.Y : 0
);
},
ignoreMissingLibraryDir: true
);
}
private static void OnExit(object? sender, ControlledApplicationLifetimeExitEventArgs args) private static void OnExit(object? sender, ControlledApplicationLifetimeExitEventArgs args)
{ {
Debug.WriteLine("Start OnExit"); Debug.WriteLine("Start OnExit");
@ -610,7 +678,8 @@ public sealed class App : Application
} }
Debug.WriteLine("Start OnExit: Disposing services"); Debug.WriteLine("Start OnExit: Disposing services");
// Dispose all services
// Dispose IDisposable services
foreach (var disposable in Services.GetServices<IDisposable>()) foreach (var disposable in Services.GetServices<IDisposable>())
{ {
Debug.WriteLine($"Disposing {disposable.GetType().Name}"); Debug.WriteLine($"Disposing {disposable.GetType().Name}");

4
StabilityMatrix.Avalonia/Controls/ControlNetCard.axaml

@ -2,7 +2,6 @@
xmlns="https://github.com/avaloniaui" xmlns="https://github.com/avaloniaui"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:controls="using:StabilityMatrix.Avalonia.Controls" xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
xmlns:extensions="clr-namespace:StabilityMatrix.Avalonia.Controls.Extensions"
xmlns:fluentIcons="clr-namespace:FluentIcons.FluentAvalonia;assembly=FluentIcons.FluentAvalonia" xmlns:fluentIcons="clr-namespace:FluentIcons.FluentAvalonia;assembly=FluentIcons.FluentAvalonia"
xmlns:input="clr-namespace:FluentAvalonia.UI.Input;assembly=FluentAvalonia" xmlns:input="clr-namespace:FluentAvalonia.UI.Input;assembly=FluentAvalonia"
xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages" xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
@ -11,6 +10,7 @@
xmlns:sg="clr-namespace:SpacedGridControl.Avalonia;assembly=SpacedGridControl.Avalonia" xmlns:sg="clr-namespace:SpacedGridControl.Avalonia;assembly=SpacedGridControl.Avalonia"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia" xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference" xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
xmlns:markupExtensions="clr-namespace:StabilityMatrix.Avalonia.MarkupExtensions"
x:DataType="vmInference:ControlNetCardViewModel"> x:DataType="vmInference:ControlNetCardViewModel">
<Design.PreviewWith> <Design.PreviewWith>
<StackPanel Width="400" Height="500"> <StackPanel Width="400" Height="500">
@ -53,7 +53,7 @@
Grid.Column="1" Grid.Column="1"
Margin="0,0,0,4" Margin="0,0,0,4"
HorizontalAlignment="Stretch" HorizontalAlignment="Stretch"
extensions:ShowDisabledTooltipExtension.ShowOnDisabled="True" markupExtensions:ShowDisabledTooltipExtension.ShowOnDisabled="True"
Header="{x:Static lang:Resources.Label_Preprocessor}" Header="{x:Static lang:Resources.Label_Preprocessor}"
IsEnabled="False" IsEnabled="False"
Theme="{StaticResource FAComboBoxHybridModelTheme}" Theme="{StaticResource FAComboBoxHybridModelTheme}"

6
StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml

@ -41,10 +41,10 @@
<Panel> <Panel>
<!-- Image --> <!-- Image -->
<controls:BetterAdvancedImage <controls:BetterAdvancedImage
x:Name="PART_BetterAdvancedImage"
VerticalAlignment="Stretch" VerticalAlignment="Stretch"
VerticalContentAlignment="Stretch" VerticalContentAlignment="Stretch"
CornerRadius="4" CornerRadius="4"
CurrentImage="{Binding CurrentBitmap, Mode=TwoWay}"
IsVisible="{Binding !IsSelectionAvailable}" IsVisible="{Binding !IsSelectionAvailable}"
RenderOptions.BitmapInterpolationMode="HighQuality" RenderOptions.BitmapInterpolationMode="HighQuality"
Source="{Binding ImageSource}" Source="{Binding ImageSource}"
@ -53,8 +53,8 @@
<!-- Missing image --> <!-- Missing image -->
<Border <Border
BorderThickness="3"
BorderBrush="{StaticResource ThemeCoralRedColor}" BorderBrush="{StaticResource ThemeCoralRedColor}"
BorderThickness="3"
BoxShadow="inset 1.2 0 20 1.8 #66000000" BoxShadow="inset 1.2 0 20 1.8 #66000000"
CornerRadius="4" CornerRadius="4"
IsVisible="{Binding IsImageFileNotFound}"> IsVisible="{Binding IsImageFileNotFound}">
@ -75,8 +75,8 @@
TextWrapping="WrapWithOverflow" /> TextWrapping="WrapWithOverflow" />
<SelectableTextBlock <SelectableTextBlock
Grid.Row="2" Grid.Row="2"
FontSize="10"
Margin="0,4,0,0" Margin="0,4,0,0"
FontSize="10"
Foreground="{DynamicResource TextFillColorTertiaryBrush}" Foreground="{DynamicResource TextFillColorTertiaryBrush}"
Text="{Binding NotFoundImagePath}" Text="{Binding NotFoundImagePath}"
TextAlignment="Center" TextAlignment="Center"

37
StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs

@ -1,6 +1,39 @@
using StabilityMatrix.Core.Attributes; using System;
using System.Drawing;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using DynamicData.Binding;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls; namespace StabilityMatrix.Avalonia.Controls;
[Transient] [Transient]
public class SelectImageCard : DropTargetTemplatedControlBase { } public class SelectImageCard : DropTargetTemplatedControlBase
{
/// <inheritdoc />
protected override void OnApplyTemplate(TemplateAppliedEventArgs e)
{
base.OnApplyTemplate(e);
if (DataContext is not SelectImageCardViewModel vm)
return;
if (e.NameScope.Find<BetterAdvancedImage>("PART_BetterAdvancedImage") is not { } image)
return;
image
.WhenPropertyChanged(x => x.CurrentImage)
.Subscribe(propertyValue =>
{
if (propertyValue.Value?.Size is { } size)
{
vm.CurrentBitmapSize = new Size(Convert.ToInt32(size.Width), Convert.ToInt32(size.Height));
}
else
{
vm.CurrentBitmapSize = Size.Empty;
}
});
}
}

2
StabilityMatrix.Avalonia/Controls/Extensions/ShowDisabledTooltipExtension.cs → StabilityMatrix.Avalonia/MarkupExtensions/ShowDisabledTooltipExtension.cs

@ -6,7 +6,7 @@ using Avalonia.Interactivity;
using Avalonia.VisualTree; using Avalonia.VisualTree;
using JetBrains.Annotations; using JetBrains.Annotations;
namespace StabilityMatrix.Avalonia.Controls.Extensions; namespace StabilityMatrix.Avalonia.MarkupExtensions;
/// <summary> /// <summary>
/// Show tooltip on Controls with IsEffectivelyEnabled = false /// Show tooltip on Controls with IsEffectivelyEnabled = false

42
StabilityMatrix.Avalonia/MarkupExtensions/TernaryExtension.cs

@ -0,0 +1,42 @@
using System;
using System.Globalization;
using Avalonia.Data;
using Avalonia.Data.Converters;
using Avalonia.Markup.Xaml;
using Avalonia.Markup.Xaml.MarkupExtensions;
namespace StabilityMatrix.Avalonia.MarkupExtensions;
/// <summary>
/// https://github.com/AvaloniaUI/Avalonia/discussions/7408
/// </summary>
/// <example>
/// <code>{e:Ternary SomeProperty, True=1, False=0}</code>
/// </example>
public class TernaryExtension : MarkupExtension
{
public string Path { get; set; }
public Type Type { get; set; }
public object? True { get; set; }
public object? False { get; set; }
public override object ProvideValue(IServiceProvider serviceProvider)
{
var cultureInfo = CultureInfo.GetCultureInfo("en-US");
var binding = new ReflectionBindingExtension(Path)
{
Mode = BindingMode.OneWay,
Converter = new FuncValueConverter<bool, object?>(
isTrue =>
isTrue
? Convert.ChangeType(True, Type, cultureInfo.NumberFormat)
: Convert.ChangeType(False, Type, cultureInfo.NumberFormat)
)
};
return binding.ProvideValue(serviceProvider);
}
}

138
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -56,8 +56,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
private readonly SourceCache<HybridModelFile, string> modelsSource = new(p => p.GetId()); private readonly SourceCache<HybridModelFile, string> modelsSource = new(p => p.GetId());
public IObservableCollection<HybridModelFile> Models { get; } = public IObservableCollection<HybridModelFile> Models { get; } = new ObservableCollectionExtended<HybridModelFile>();
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<HybridModelFile, string> vaeModelsSource = new(p => p.GetId()); private readonly SourceCache<HybridModelFile, string> vaeModelsSource = new(p => p.GetId());
@ -66,29 +65,24 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
public IObservableCollection<HybridModelFile> VaeModels { get; } = public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>(); new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<HybridModelFile, string> controlNetModelsSource = private readonly SourceCache<HybridModelFile, string> controlNetModelsSource = new(p => p.GetId());
new(p => p.GetId());
private readonly SourceCache<HybridModelFile, string> downloadableControlNetModelsSource = private readonly SourceCache<HybridModelFile, string> downloadableControlNetModelsSource = new(p => p.GetId());
new(p => p.GetId());
public IObservableCollection<HybridModelFile> ControlNetModels { get; } = public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
new ObservableCollectionExtended<HybridModelFile>(); new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name); private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);
public IObservableCollection<ComfySampler> Samplers { get; } = public IObservableCollection<ComfySampler> Samplers { get; } = new ObservableCollectionExtended<ComfySampler>();
new ObservableCollectionExtended<ComfySampler>();
private readonly SourceCache<ComfyUpscaler, string> modelUpscalersSource = new(p => p.Name); private readonly SourceCache<ComfyUpscaler, string> modelUpscalersSource = new(p => p.Name);
private readonly SourceCache<ComfyUpscaler, string> latentUpscalersSource = new(p => p.Name); private readonly SourceCache<ComfyUpscaler, string> latentUpscalersSource = new(p => p.Name);
private readonly SourceCache<ComfyUpscaler, string> downloadableUpscalersSource = private readonly SourceCache<ComfyUpscaler, string> downloadableUpscalersSource = new(p => p.Name);
new(p => p.Name);
public IObservableCollection<ComfyUpscaler> Upscalers { get; } = public IObservableCollection<ComfyUpscaler> Upscalers { get; } = new ObservableCollectionExtended<ComfyUpscaler>();
new ObservableCollectionExtended<ComfyUpscaler>();
private readonly SourceCache<ComfyScheduler, string> schedulersSource = new(p => p.Name); private readonly SourceCache<ComfyScheduler, string> schedulersSource = new(p => p.Name);
@ -111,11 +105,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
modelsSource modelsSource
.Connect() .Connect()
.SortBy( .SortBy(f => f.ShortDisplayName, SortDirection.Ascending, SortOptimisations.ComparesImmutableValuesOnly)
f => f.ShortDisplayName,
SortDirection.Ascending,
SortOptimisations.ComparesImmutableValuesOnly
)
.DeferUntilLoaded() .DeferUntilLoaded()
.Bind(Models) .Bind(Models)
.Subscribe(); .Subscribe();
@ -124,9 +114,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Connect() .Connect()
.Or(downloadableControlNetModelsSource.Connect()) .Or(downloadableControlNetModelsSource.Connect())
.Sort( .Sort(
SortExpressionComparer<HybridModelFile> SortExpressionComparer<HybridModelFile>.Ascending(f => f.Type).ThenByAscending(f => f.ShortDisplayName)
.Ascending(f => f.Type)
.ThenByAscending(f => f.ShortDisplayName)
) )
.DeferUntilLoaded() .DeferUntilLoaded()
.Bind(ControlNetModels) .Bind(ControlNetModels)
@ -142,11 +130,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Connect() .Connect()
.Or(modelUpscalersSource.Connect()) .Or(modelUpscalersSource.Connect())
.Or(downloadableUpscalersSource.Connect()) .Or(downloadableUpscalersSource.Connect())
.Sort( .Sort(SortExpressionComparer<ComfyUpscaler>.Ascending(f => f.Type).ThenByAscending(f => f.Name))
SortExpressionComparer<ComfyUpscaler>
.Ascending(f => f.Type)
.ThenByAscending(f => f.Name)
)
.Bind(Upscalers) .Bind(Upscalers)
.Subscribe(); .Subscribe();
@ -169,9 +153,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
if (IsConnected) if (IsConnected)
{ {
LoadSharedPropertiesAsync() LoadSharedPropertiesAsync()
.SafeFireAndForget( .SafeFireAndForget(onException: ex => logger.LogError(ex, "Error loading shared properties"));
onException: ex => logger.LogError(ex, "Error loading shared properties")
);
} }
}; };
} }
@ -190,17 +172,11 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
// Get model names // Get model names
if (await Client.GetModelNamesAsync() is { } modelNames) if (await Client.GetModelNamesAsync() is { } modelNames)
{ {
modelsSource.EditDiff( modelsSource.EditDiff(modelNames.Select(HybridModelFile.FromRemote), HybridModelFile.Comparer);
modelNames.Select(HybridModelFile.FromRemote),
HybridModelFile.Comparer
);
} }
// Get control net model names // Get control net model names
if ( if (await Client.GetNodeOptionNamesAsync("ControlNetLoader", "control_net_name") is { } controlNetModelNames)
await Client.GetNodeOptionNamesAsync("ControlNetLoader", "control_net_name") is
{ } controlNetModelNames
)
{ {
controlNetModelsSource.EditDiff( controlNetModelsSource.EditDiff(
controlNetModelNames.Select(HybridModelFile.FromRemote), controlNetModelNames.Select(HybridModelFile.FromRemote),
@ -211,19 +187,13 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
// Fetch sampler names from KSampler node // Fetch sampler names from KSampler node
if (await Client.GetSamplerNamesAsync() is { } samplerNames) if (await Client.GetSamplerNamesAsync() is { } samplerNames)
{ {
samplersSource.EditDiff( samplersSource.EditDiff(samplerNames.Select(name => new ComfySampler(name)), ComfySampler.Comparer);
samplerNames.Select(name => new ComfySampler(name)),
ComfySampler.Comparer
);
} }
// Upscalers is latent and esrgan combined // Upscalers is latent and esrgan combined
// Add latent upscale methods from LatentUpscale node // Add latent upscale methods from LatentUpscale node
if ( if (await Client.GetNodeOptionNamesAsync("LatentUpscale", "upscale_method") is { } latentUpscalerNames)
await Client.GetNodeOptionNamesAsync("LatentUpscale", "upscale_method") is
{ } latentUpscalerNames
)
{ {
latentUpscalersSource.EditDiff( latentUpscalersSource.EditDiff(
latentUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)), latentUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)),
@ -234,10 +204,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
} }
// Add Model upscale methods // Add Model upscale methods
if ( if (await Client.GetNodeOptionNamesAsync("UpscaleModelLoader", "model_name") is { } modelUpscalerNames)
await Client.GetNodeOptionNamesAsync("UpscaleModelLoader", "model_name") is
{ } modelUpscalerNames
)
{ {
modelUpscalersSource.EditDiff( modelUpscalersSource.EditDiff(
modelUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)), modelUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)),
@ -249,10 +216,12 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
// Add scheduler names from Scheduler node // Add scheduler names from Scheduler node
if (await Client.GetNodeOptionNamesAsync("KSampler", "scheduler") is { } schedulerNames) if (await Client.GetNodeOptionNamesAsync("KSampler", "scheduler") is { } schedulerNames)
{ {
schedulersSource.EditDiff( schedulersSource.Edit(updater =>
schedulerNames.Select(s => new ComfyScheduler(s)), {
ComfyScheduler.Comparer updater.AddOrUpdate(
); schedulerNames.Where(n => !schedulersSource.Keys.Contains(n)).Select(s => new ComfyScheduler(s))
);
});
logger.LogTrace("Loaded scheduler methods: {@Schedulers}", schedulerNames); logger.LogTrace("Loaded scheduler methods: {@Schedulers}", schedulerNames);
} }
} }
@ -264,34 +233,25 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
{ {
// Load local models // Load local models
modelsSource.EditDiff( modelsSource.EditDiff(
modelIndexService modelIndexService.GetFromModelIndex(SharedFolderType.StableDiffusion).Select(HybridModelFile.FromLocal),
.GetFromModelIndex(SharedFolderType.StableDiffusion)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer HybridModelFile.Comparer
); );
// Load local control net models // Load local control net models
controlNetModelsSource.EditDiff( controlNetModelsSource.EditDiff(
modelIndexService modelIndexService.GetFromModelIndex(SharedFolderType.ControlNet).Select(HybridModelFile.FromLocal),
.GetFromModelIndex(SharedFolderType.ControlNet)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer HybridModelFile.Comparer
); );
// Downloadable ControlNet models // Downloadable ControlNet models
var downloadableControlNets = RemoteModels.ControlNetModels.Where( var downloadableControlNets = RemoteModels
u => !modelUpscalersSource.Lookup(u.GetId()).HasValue .ControlNetModels
); .Where(u => !modelUpscalersSource.Lookup(u.GetId()).HasValue);
downloadableControlNetModelsSource.EditDiff( downloadableControlNetModelsSource.EditDiff(downloadableControlNets, HybridModelFile.Comparer);
downloadableControlNets,
HybridModelFile.Comparer
);
// Load local VAE models // Load local VAE models
vaeModelsSource.EditDiff( vaeModelsSource.EditDiff(
modelIndexService modelIndexService.GetFromModelIndex(SharedFolderType.VAE).Select(HybridModelFile.FromLocal),
.GetFromModelIndex(SharedFolderType.VAE)
.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer HybridModelFile.Comparer
); );
@ -304,25 +264,20 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
// Load Upscalers // Load Upscalers
modelUpscalersSource.EditDiff( modelUpscalersSource.EditDiff(
modelIndexService modelIndexService
.GetFromModelIndex( .GetFromModelIndex(SharedFolderType.ESRGAN | SharedFolderType.RealESRGAN | SharedFolderType.SwinIR)
SharedFolderType.ESRGAN | SharedFolderType.RealESRGAN | SharedFolderType.SwinIR
)
.Select(m => new ComfyUpscaler(m.FileName, ComfyUpscalerType.ESRGAN)), .Select(m => new ComfyUpscaler(m.FileName, ComfyUpscalerType.ESRGAN)),
ComfyUpscaler.Comparer ComfyUpscaler.Comparer
); );
// Remote upscalers // Remote upscalers
var remoteUpscalers = ComfyUpscaler.DefaultDownloadableModels.Where( var remoteUpscalers = ComfyUpscaler
u => !modelUpscalersSource.Lookup(u.Name).HasValue .DefaultDownloadableModels
); .Where(u => !modelUpscalersSource.Lookup(u.Name).HasValue);
downloadableUpscalersSource.EditDiff(remoteUpscalers, ComfyUpscaler.Comparer); downloadableUpscalersSource.EditDiff(remoteUpscalers, ComfyUpscaler.Comparer);
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task UploadInputImageAsync( public async Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default)
ImageSource image,
CancellationToken cancellationToken = default
)
{ {
EnsureConnected(); EnsureConnected();
@ -338,10 +293,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task CopyImageToInputAsync( public async Task CopyImageToInputAsync(FilePath imageFile, CancellationToken cancellationToken = default)
FilePath imageFile,
CancellationToken cancellationToken = default
)
{ {
if (!IsConnected) if (!IsConnected)
return; return;
@ -370,10 +322,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task WriteImageToInputAsync( public async Task WriteImageToInputAsync(ImageSource imageSource, CancellationToken cancellationToken = default)
ImageSource imageSource,
CancellationToken cancellationToken = default
)
{ {
if (!IsConnected) if (!IsConnected)
return; return;
@ -436,16 +385,14 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
// For locally installed packages only // For locally installed packages only
// Delete ./output/Inference // Delete ./output/Inference
var legacyInferenceLinkDir = new DirectoryPath( var legacyInferenceLinkDir = new DirectoryPath(packagePair.InstalledPackage.FullPath).JoinDir(
packagePair.InstalledPackage.FullPath "output",
).JoinDir("output", "Inference"); "Inference"
);
if (legacyInferenceLinkDir.Exists) if (legacyInferenceLinkDir.Exists)
{ {
logger.LogInformation( logger.LogInformation("Deleting legacy inference link at {LegacyDir}", legacyInferenceLinkDir);
"Deleting legacy inference link at {LegacyDir}",
legacyInferenceLinkDir
);
if (legacyInferenceLinkDir.IsSymbolicLink) if (legacyInferenceLinkDir.IsSymbolicLink)
{ {
@ -462,10 +409,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task ConnectAsync( public async Task ConnectAsync(PackagePair packagePair, CancellationToken cancellationToken = default)
PackagePair packagePair,
CancellationToken cancellationToken = default
)
{ {
if (IsConnected) if (IsConnected)
return; return;

14
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -271,13 +271,17 @@ public abstract partial class InferenceGenerationViewModelBase : InferenceTabVie
Task.Run( Task.Run(
async () => async () =>
{ {
var delayTime = 250 - (int)timer.ElapsedMilliseconds; try
if (delayTime > 0)
{ {
await Task.Delay(delayTime, cancellationToken); var delayTime = 250 - (int)timer.ElapsedMilliseconds;
if (delayTime > 0)
{
await Task.Delay(delayTime, cancellationToken);
}
// ReSharper disable once AccessToDisposedClosure
AttachRunningNodeChangedHandler(promptTask);
} }
// ReSharper disable once AccessToDisposedClosure catch (TaskCanceledException) { }
AttachRunningNodeChangedHandler(promptTask);
}, },
cancellationToken cancellationToken
) )

3
StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs

@ -124,7 +124,8 @@ public partial class CivitAiBrowserViewModel : TabViewModelBase
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0) .Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString()); .OrderBy(t => t.ToString());
public List<string> BaseModelOptions => new() { "All", "SD 1.5", "SD 2.1", "SDXL 0.9", "SDXL 1.0" }; public List<string> BaseModelOptions =>
["All", "SD 1.5", "SD 1.5 LCM", "SD 2.1", "SDXL 0.9", "SDXL 1.0", "SDXL 1.0 LCM", "SDXL Turbo", "Other"];
public CivitAiBrowserViewModel( public CivitAiBrowserViewModel(
ICivitApi civitApi, ICivitApi civitApi,

51
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -16,6 +16,8 @@ using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using Size = System.Drawing.Size;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration #pragma warning disable CS0657 // Not a valid attribute location for this declaration
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -109,6 +111,8 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
Width, Width,
Height Height
); );
e.Builder.Connections.PrimarySize = new Size(Width, Height);
} }
// Provide temp values // Provide temp values
@ -147,6 +151,53 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
e.Builder.Connections.PrimaryScheduler = e.Builder.Connections.PrimaryScheduler =
SelectedScheduler ?? throw new ValidationException("Scheduler not selected"); SelectedScheduler ?? throw new ValidationException("Scheduler not selected");
// Use custom sampler if SDTurbo scheduler is selected
if (e.Builder.Connections.PrimaryScheduler == ComfyScheduler.SDTurbo)
{
// Error if using refiner
if (e.Builder.Connections.RefinerModel is not null)
{
throw new ValidationException("SDTurbo Scheduler cannot be used with Refiner Model");
}
var kSamplerSelect = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.KSamplerSelect
{
Name = "KSamplerSelect",
SamplerName = e.Builder.Connections.PrimarySampler?.Name!
}
);
var turboScheduler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.SDTurboScheduler
{
Name = "SDTurboScheduler",
Model = e.Builder.Connections.BaseModel ?? throw new ArgumentException("No BaseModel"),
Steps = Steps
}
);
var sampler = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.SamplerCustom
{
Name = "Sampler",
Model = e.Builder.Connections.BaseModel ?? throw new ArgumentException("No BaseModel"),
AddNoise = true,
NoiseSeed = e.Builder.Connections.Seed,
Cfg = CfgScale,
Positive = e.Temp.Conditioning?.Positive!,
Negative = e.Temp.Conditioning?.Negative!,
Sampler = kSamplerSelect.Output,
Sigmas = turboScheduler.Output,
LatentImage = primaryLatent
}
);
e.Builder.Connections.Primary = sampler.Output1;
return;
}
// Use KSampler if no refiner, otherwise need KSamplerAdvanced // Use KSampler if no refiner, otherwise need KSamplerAdvanced
if (e.Builder.Connections.RefinerModel is null) if (e.Builder.Connections.RefinerModel is null)
{ {

23
StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs

@ -8,7 +8,6 @@ using System.Text.Json.Serialization;
using System.Threading.Tasks; using System.Threading.Tasks;
using AsyncAwaitBestPractices; using AsyncAwaitBestPractices;
using Avalonia.Input; using Avalonia.Input;
using Avalonia.Media;
using Avalonia.Platform.Storage; using Avalonia.Platform.Storage;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
@ -43,40 +42,36 @@ public partial class SelectImageCardViewModel(INotificationService notificationS
[NotifyPropertyChangedFor(nameof(IsImageFileNotFound))] [NotifyPropertyChangedFor(nameof(IsImageFileNotFound))]
private ImageSource? imageSource; private ImageSource? imageSource;
[ObservableProperty]
[property: JsonIgnore]
[NotifyPropertyChangedFor(nameof(CurrentBitmapSize))]
private IImage? currentBitmap;
[ObservableProperty] [ObservableProperty]
[property: JsonIgnore] [property: JsonIgnore]
[NotifyPropertyChangedFor(nameof(IsSelectionAvailable))] [NotifyPropertyChangedFor(nameof(IsSelectionAvailable))]
private bool isSelectionEnabled = true; private bool isSelectionEnabled = true;
/// <summary>
/// Set by <see cref="SelectImageCard"/> when the image is loaded.
/// </summary>
[ObservableProperty]
private Size currentBitmapSize = Size.Empty;
/// <summary> /// <summary>
/// True if the image file is set but the local file does not exist. /// True if the image file is set but the local file does not exist.
/// </summary> /// </summary>
[MemberNotNullWhen(true, nameof(NotFoundImagePath))] [MemberNotNullWhen(true, nameof(NotFoundImagePath))]
public bool IsImageFileNotFound => ImageSource?.LocalFile?.Exists == false; public bool IsImageFileNotFound => ImageSource?.LocalFile?.Exists == false;
public bool IsSelectionAvailable => IsSelectionEnabled && ImageSource == null && !IsImageFileNotFound; public bool IsSelectionAvailable => IsSelectionEnabled && ImageSource == null;
/// <summary> /// <summary>
/// Path of the not found image /// Path of the not found image
/// </summary> /// </summary>
public string? NotFoundImagePath => ImageSource?.LocalFile?.FullPath; public string? NotFoundImagePath => ImageSource?.LocalFile?.FullPath;
public Size? CurrentBitmapSize =>
CurrentBitmap is null
? null
: new Size(Convert.ToInt32(CurrentBitmap.Size.Width), Convert.ToInt32(CurrentBitmap.Size.Height));
/// <inheritdoc /> /// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e) public void ApplyStep(ModuleApplyStepEventArgs e)
{ {
e.Builder.SetupImagePrimarySource( e.Builder.SetupImagePrimarySource(
ImageSource ?? throw new ValidationException("Input Image is required"), ImageSource ?? throw new ValidationException("Input Image is required"),
CurrentBitmapSize ?? throw new ValidationException("Input Image is required"), !CurrentBitmapSize.IsEmpty ? CurrentBitmapSize : throw new ValidationException("CurrentBitmapSize is null"),
e.Builder.Connections.BatchIndex e.Builder.Connections.BatchIndex
); );
} }
@ -108,7 +103,7 @@ public partial class SelectImageCardViewModel(INotificationService notificationS
if (files.FirstOrDefault()?.TryGetLocalPath() is { } path) if (files.FirstOrDefault()?.TryGetLocalPath() is { } path)
{ {
LoadUserImageSafe(new ImageSource(path)); Dispatcher.UIThread.Post(() => LoadUserImageSafe(new ImageSource(path)));
} }
} }

15
StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs

@ -41,8 +41,8 @@ namespace StabilityMatrix.Avalonia.ViewModels;
[Preload] [Preload]
[View(typeof(InferencePage))] [View(typeof(InferencePage))]
[Singleton] [Singleton, Singleton(typeof(IAsyncDisposable))]
public partial class InferenceViewModel : PageViewModelBase public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
@ -230,16 +230,13 @@ public partial class InferenceViewModel : PageViewModelBase
} }
/// <summary> /// <summary>
/// On Unloaded, sync tab states to database /// On exit, sync tab states to database
/// </summary> /// </summary>
public override async Task OnUnloadedAsync() public async ValueTask DisposeAsync()
{ {
await base.OnUnloadedAsync();
if (Design.IsDesignMode)
return;
await SyncTabStatesWithDatabase(); await SyncTabStatesWithDatabase();
GC.SuppressFinalize(this);
} }
private void OnInferenceTextToImageRequested(object? sender, LocalImageFile e) private void OnInferenceTextToImageRequested(object? sender, LocalImageFile e)

11
StabilityMatrix.Core/Models/Api/Comfy/ComfySampler.cs

@ -4,11 +4,15 @@ using System.Diagnostics.CodeAnalysis;
namespace StabilityMatrix.Core.Models.Api.Comfy; namespace StabilityMatrix.Core.Models.Api.Comfy;
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "StringLiteralTypo")]
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "IdentifierTypo")]
public readonly record struct ComfySampler(string Name) public readonly record struct ComfySampler(string Name)
{ {
public static ComfySampler Euler { get; } = new("euler"); public static ComfySampler Euler { get; } = new("euler");
public static ComfySampler EulerAncestral { get; } = new("euler_ancestral"); public static ComfySampler EulerAncestral { get; } = new("euler_ancestral");
public static ComfySampler Heun { get; } = new("heun"); public static ComfySampler Heun { get; } = new("heun");
public static ComfySampler HeunPp2 { get; } = new("heunpp2");
public static ComfySampler Dpm2 { get; } = new("dpm_2"); public static ComfySampler Dpm2 { get; } = new("dpm_2");
public static ComfySampler Dpm2Ancestral { get; } = new("dpm_2_ancestral"); public static ComfySampler Dpm2Ancestral { get; } = new("dpm_2_ancestral");
public static ComfySampler LMS { get; } = new("lms"); public static ComfySampler LMS { get; } = new("lms");
@ -24,8 +28,10 @@ public readonly record struct ComfySampler(string Name)
public static ComfySampler Dpmpp3MSde { get; } = new("dpmpp_3m_sde"); public static ComfySampler Dpmpp3MSde { get; } = new("dpmpp_3m_sde");
public static ComfySampler Dpmpp3MSdeGpu { get; } = new("dpmpp_3m_sde_gpu"); public static ComfySampler Dpmpp3MSdeGpu { get; } = new("dpmpp_3m_sde_gpu");
public static ComfySampler DDIM { get; } = new("ddim"); public static ComfySampler DDIM { get; } = new("ddim");
public static ComfySampler DDPM { get; } = new("ddpm");
public static ComfySampler UniPC { get; } = new("uni_pc"); public static ComfySampler UniPC { get; } = new("uni_pc");
public static ComfySampler UniPCBh2 { get; } = new("uni_pc_bh2"); public static ComfySampler UniPCBh2 { get; } = new("uni_pc_bh2");
public static ComfySampler LCM { get; } = new("lcm");
private static Dictionary<ComfySampler, string> ConvertDict { get; } = private static Dictionary<ComfySampler, string> ConvertDict { get; } =
new() new()
@ -33,6 +39,7 @@ public readonly record struct ComfySampler(string Name)
[Euler] = "Euler", [Euler] = "Euler",
[EulerAncestral] = "Euler Ancestral", [EulerAncestral] = "Euler Ancestral",
[Heun] = "Heun", [Heun] = "Heun",
[HeunPp2] = "Heun++ 2",
[Dpm2] = "DPM 2", [Dpm2] = "DPM 2",
[Dpm2Ancestral] = "DPM 2 Ancestral", [Dpm2Ancestral] = "DPM 2 Ancestral",
[LMS] = "LMS", [LMS] = "LMS",
@ -48,8 +55,10 @@ public readonly record struct ComfySampler(string Name)
[Dpmpp3MSde] = "DPM++ 3M SDE", [Dpmpp3MSde] = "DPM++ 3M SDE",
[Dpmpp3MSdeGpu] = "DPM++ 3M SDE GPU", [Dpmpp3MSdeGpu] = "DPM++ 3M SDE GPU",
[DDIM] = "DDIM", [DDIM] = "DDIM",
[DDPM] = "DDPM",
[UniPC] = "UniPC", [UniPC] = "UniPC",
[UniPCBh2] = "UniPC BH2" [UniPCBh2] = "UniPC BH2",
[LCM] = "LCM"
}; };
public static IReadOnlyList<ComfySampler> Defaults { get; } = ConvertDict.Keys.ToImmutableArray(); public static IReadOnlyList<ComfySampler> Defaults { get; } = ConvertDict.Keys.ToImmutableArray();

11
StabilityMatrix.Core/Models/Api/Comfy/ComfyScheduler.cs

@ -7,23 +7,24 @@ public readonly record struct ComfyScheduler(string Name)
public static ComfyScheduler Normal { get; } = new("normal"); public static ComfyScheduler Normal { get; } = new("normal");
public static ComfyScheduler Karras { get; } = new("karras"); public static ComfyScheduler Karras { get; } = new("karras");
public static ComfyScheduler Exponential { get; } = new("exponential"); public static ComfyScheduler Exponential { get; } = new("exponential");
public static ComfyScheduler SDTurbo { get; } = new("sd_turbo");
private static Dictionary<string, string> ConvertDict { get; } = private static Dictionary<string, string> ConvertDict { get; } =
new() new()
{ {
[Normal.Name] = "Normal", [Normal.Name] = "Normal",
["karras"] = "Karras", [Karras.Name] = "Karras",
["exponential"] = "Exponential", [Exponential.Name] = "Exponential",
["sgm_uniform"] = "SGM Uniform", ["sgm_uniform"] = "SGM Uniform",
["simple"] = "Simple", ["simple"] = "Simple",
["ddim_uniform"] = "DDIM Uniform" ["ddim_uniform"] = "DDIM Uniform",
[SDTurbo.Name] = "SD Turbo"
}; };
public static IReadOnlyList<ComfyScheduler> Defaults { get; } = public static IReadOnlyList<ComfyScheduler> Defaults { get; } =
ConvertDict.Keys.Select(k => new ComfyScheduler(k)).ToImmutableArray(); ConvertDict.Keys.Select(k => new ComfyScheduler(k)).ToImmutableArray();
public string DisplayName => public string DisplayName => ConvertDict.GetValueOrDefault(Name, Name);
ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name;
private sealed class NameEqualityComparer : IEqualityComparer<ComfyScheduler> private sealed class NameEqualityComparer : IEqualityComparer<ComfyScheduler>
{ {

4
StabilityMatrix.Core/Models/Api/Comfy/NodeTypes/NodeConnections.cs

@ -17,3 +17,7 @@ public class ConditioningNodeConnection : NodeConnectionBase { }
public class ClipNodeConnection : NodeConnectionBase { } public class ClipNodeConnection : NodeConnectionBase { }
public class ControlNetNodeConnection : NodeConnectionBase { } public class ControlNetNodeConnection : NodeConnectionBase { }
public class SamplerNodeConnection : NodeConnectionBase { }
public class SigmasNodeConnection : NodeConnectionBase { }

227
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -1,4 +1,5 @@
using System.Diagnostics.CodeAnalysis; using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Drawing; using System.Drawing;
using System.Runtime.Serialization; using System.Runtime.Serialization;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
@ -26,9 +27,7 @@ public class ComfyNodeBuilder
{ {
if (i > 1_000_000) if (i > 1_000_000)
{ {
throw new InvalidOperationException( throw new InvalidOperationException($"Could not find unique name for base {nameBase}");
$"Could not find unique name for base {nameBase}"
);
} }
name = $"{nameBase}_{i + 1}"; name = $"{nameBase}_{i + 1}";
@ -63,39 +62,6 @@ public class ComfyNodeBuilder
public required double Denoise { get; init; } public required double Denoise { get; init; }
} }
/*public static NamedComfyNode<LatentNodeConnection> KSampler(
string name,
ModelNodeConnection model,
ulong seed,
int steps,
double cfg,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
double denoise
)
{
return new NamedComfyNode<LatentNodeConnection>(name)
{
ClassType = "KSampler",
Inputs = new Dictionary<string, object?>
{
["model"] = model.Data,
["seed"] = seed,
["steps"] = steps,
["cfg"] = cfg,
["sampler_name"] = sampler.Name,
["scheduler"] = scheduler.Name,
["positive"] = positive.Data,
["negative"] = negative.Data,
["latent_image"] = latentImage.Data,
["denoise"] = denoise
}
};
}*/
public record KSamplerAdvanced : ComfyTypedNodeBase<LatentNodeConnection> public record KSamplerAdvanced : ComfyTypedNodeBase<LatentNodeConnection>
{ {
public required ModelNodeConnection Model { get; init; } public required ModelNodeConnection Model { get; init; }
@ -117,44 +83,34 @@ public class ComfyNodeBuilder
public bool ReturnWithLeftoverNoise { get; init; } public bool ReturnWithLeftoverNoise { get; init; }
} }
/*public static NamedComfyNode<LatentNodeConnection> KSamplerAdvanced( public record SamplerCustom : ComfyTypedNodeBase<LatentNodeConnection, LatentNodeConnection>
string name,
ModelNodeConnection model,
bool addNoise,
ulong noiseSeed,
int steps,
double cfg,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
int startAtStep,
int endAtStep,
bool returnWithLeftoverNoise
)
{ {
return new NamedComfyNode<LatentNodeConnection>(name) public required ModelNodeConnection Model { get; init; }
{ public required bool AddNoise { get; init; }
ClassType = "KSamplerAdvanced", public required ulong NoiseSeed { get; init; }
Inputs = new Dictionary<string, object?>
{ [Range(0d, 100d)]
["model"] = model.Data, public required double Cfg { get; init; }
["add_noise"] = addNoise ? "enable" : "disable",
["noise_seed"] = noiseSeed, public required ConditioningNodeConnection Positive { get; init; }
["steps"] = steps, public required ConditioningNodeConnection Negative { get; init; }
["cfg"] = cfg, public required SamplerNodeConnection Sampler { get; init; }
["sampler_name"] = sampler.Name, public required SigmasNodeConnection Sigmas { get; init; }
["scheduler"] = scheduler.Name, public required LatentNodeConnection LatentImage { get; init; }
["positive"] = positive.Data, }
["negative"] = negative.Data,
["latent_image"] = latentImage.Data, public record KSamplerSelect : ComfyTypedNodeBase<SamplerNodeConnection>
["start_at_step"] = startAtStep, {
["end_at_step"] = endAtStep, public required string SamplerName { get; init; }
["return_with_leftover_noise"] = returnWithLeftoverNoise ? "enable" : "disable" }
}
}; public record SDTurboScheduler : ComfyTypedNodeBase<SigmasNodeConnection>
}*/ {
public required ModelNodeConnection Model { get; init; }
[Range(1, 10)]
public required int Steps { get; init; }
}
public record EmptyLatentImage : ComfyTypedNodeBase<LatentNodeConnection> public record EmptyLatentImage : ComfyTypedNodeBase<LatentNodeConnection>
{ {
@ -191,18 +147,11 @@ public class ComfyNodeBuilder
return new NamedComfyNode<ImageNodeConnection>(name) return new NamedComfyNode<ImageNodeConnection>(name)
{ {
ClassType = "ImageUpscaleWithModel", ClassType = "ImageUpscaleWithModel",
Inputs = new Dictionary<string, object?> Inputs = new Dictionary<string, object?> { ["upscale_model"] = upscaleModel.Data, ["image"] = image.Data }
{
["upscale_model"] = upscaleModel.Data,
["image"] = image.Data
}
}; };
} }
public static NamedComfyNode<UpscaleModelNodeConnection> UpscaleModelLoader( public static NamedComfyNode<UpscaleModelNodeConnection> UpscaleModelLoader(string name, string modelName)
string name,
string modelName
)
{ {
return new NamedComfyNode<UpscaleModelNodeConnection>(name) return new NamedComfyNode<UpscaleModelNodeConnection>(name)
{ {
@ -323,8 +272,7 @@ public class ComfyNodeBuilder
public required string ControlNetName { get; init; } public required string ControlNetName { get; init; }
} }
public record ControlNetApplyAdvanced public record ControlNetApplyAdvanced : ComfyTypedNodeBase<ConditioningNodeConnection, ConditioningNodeConnection>
: ComfyTypedNodeBase<ConditioningNodeConnection, ConditioningNodeConnection>
{ {
public required ConditioningNodeConnection Positive { get; init; } public required ConditioningNodeConnection Positive { get; init; }
public required ConditioningNodeConnection Negative { get; init; } public required ConditioningNodeConnection Negative { get; init; }
@ -335,10 +283,7 @@ public class ComfyNodeBuilder
public required double EndPercent { get; init; } public required double EndPercent { get; init; }
} }
public ImageNodeConnection Lambda_LatentToImage( public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
LatentNodeConnection latent,
VAENodeConnection vae
)
{ {
var name = GetUniqueName("VAEDecode"); var name = GetUniqueName("VAEDecode");
return Nodes return Nodes
@ -353,10 +298,7 @@ public class ComfyNodeBuilder
.Output; .Output;
} }
public LatentNodeConnection Lambda_ImageToLatent( public LatentNodeConnection Lambda_ImageToLatent(ImageNodeConnection pixels, VAENodeConnection vae)
ImageNodeConnection pixels,
VAENodeConnection vae
)
{ {
var name = GetUniqueName("VAEEncode"); var name = GetUniqueName("VAEEncode");
return Nodes return Nodes
@ -380,9 +322,7 @@ public class ComfyNodeBuilder
ImageNodeConnection image ImageNodeConnection image
) )
{ {
var modelLoader = Nodes.AddNamedNode( var modelLoader = Nodes.AddNamedNode(UpscaleModelLoader($"{name}_UpscaleModelLoader", modelName));
UpscaleModelLoader($"{name}_UpscaleModelLoader", modelName)
);
var upscaler = Nodes.AddNamedNode( var upscaler = Nodes.AddNamedNode(
ImageUpscaleWithModel($"{name}_ImageUpscaleWithModel", modelLoader.Output, image) ImageUpscaleWithModel($"{name}_ImageUpscaleWithModel", modelLoader.Output, image)
@ -425,16 +365,7 @@ public class ComfyNodeBuilder
.Output, .Output,
image => image =>
Nodes Nodes
.AddNamedNode( .AddNamedNode(ImageScale($"{name}_ImageUpscale", image, upscaleInfo.Name, height, width, false))
ImageScale(
$"{name}_ImageUpscale",
image,
upscaleInfo.Name,
height,
width,
false
)
)
.Output .Output
); );
} }
@ -445,22 +376,11 @@ public class ComfyNodeBuilder
var samplerImage = GetPrimaryAsImage(primary, vae); var samplerImage = GetPrimaryAsImage(primary, vae);
// Do group upscale // Do group upscale
var modelUpscaler = Group_UpscaleWithModel( var modelUpscaler = Group_UpscaleWithModel($"{name}_ModelUpscale", upscaleInfo.Name, samplerImage);
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size // Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode( var resizedScaled = Nodes.AddNamedNode(
ImageScale( ImageScale($"{name}_ImageScale", modelUpscaler.Output, "bilinear", height, width, false)
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
); );
return resizedScaled.Output; return resizedScaled.Output;
@ -512,22 +432,11 @@ public class ComfyNodeBuilder
); );
// Do group upscale // Do group upscale
var modelUpscaler = Group_UpscaleWithModel( var modelUpscaler = Group_UpscaleWithModel($"{name}_ModelUpscale", upscaleInfo.Name, samplerImage.Output);
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage.Output
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size // Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode( var resizedScaled = Nodes.AddNamedNode(
ImageScale( ImageScale($"{name}_ImageScale", modelUpscaler.Output, "bilinear", height, width, false)
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
); );
// Convert back to latent space // Convert back to latent space
@ -597,22 +506,11 @@ public class ComfyNodeBuilder
); );
// Do group upscale // Do group upscale
var modelUpscaler = Group_UpscaleWithModel( var modelUpscaler = Group_UpscaleWithModel($"{name}_ModelUpscale", upscaleInfo.Name, samplerImage.Output);
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage.Output
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size // Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode( var resizedScaled = Nodes.AddNamedNode(
ImageScale( ImageScale($"{name}_ImageScale", modelUpscaler.Output, "bilinear", height, width, false)
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
); );
// No need to convert back to latent space // No need to convert back to latent space
@ -654,22 +552,11 @@ public class ComfyNodeBuilder
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN) if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{ {
// Do group upscale // Do group upscale
var modelUpscaler = Group_UpscaleWithModel( var modelUpscaler = Group_UpscaleWithModel($"{name}_ModelUpscale", upscaleInfo.Name, image);
$"{name}_ModelUpscale",
upscaleInfo.Name,
image
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size // Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = Nodes.AddNamedNode( var resizedScaled = Nodes.AddNamedNode(
ImageScale( ImageScale($"{name}_ImageScale", modelUpscaler.Output, "bilinear", height, width, false)
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
); );
// No need to convert back to latent space // No need to convert back to latent space
@ -764,10 +651,7 @@ public class ComfyNodeBuilder
/// <summary> /// <summary>
/// Get or convert latest primary connection to latent /// Get or convert latest primary connection to latent
/// </summary> /// </summary>
public LatentNodeConnection GetPrimaryAsLatent( public LatentNodeConnection GetPrimaryAsLatent(PrimaryNodeConnection primary, VAENodeConnection vae)
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{ {
return primary.Match(latent => latent, image => Lambda_ImageToLatent(image, vae)); return primary.Match(latent => latent, image => Lambda_ImageToLatent(image, vae));
} }
@ -807,10 +691,7 @@ public class ComfyNodeBuilder
/// <summary> /// <summary>
/// Get or convert latest primary connection to image /// Get or convert latest primary connection to image
/// </summary> /// </summary>
public ImageNodeConnection GetPrimaryAsImage( public ImageNodeConnection GetPrimaryAsImage(PrimaryNodeConnection primary, VAENodeConnection vae)
PrimaryNodeConnection primary,
VAENodeConnection vae
)
{ {
return primary.Match(latent => Lambda_LatentToImage(latent, vae), image => image); return primary.Match(latent => Lambda_LatentToImage(latent, vae), image => image);
} }
@ -825,10 +706,7 @@ public class ComfyNodeBuilder
return Connections.Primary.AsT1; return Connections.Primary.AsT1;
} }
return GetPrimaryAsImage( return GetPrimaryAsImage(Connections.Primary ?? throw new NullReferenceException("No primary connection"), vae);
Connections.Primary ?? throw new NullReferenceException("No primary connection"),
vae
);
} }
/// <summary> /// <summary>
@ -878,9 +756,7 @@ public class ComfyNodeBuilder
public ConditioningNodeConnection GetRefinerOrBaseConditioning() public ConditioningNodeConnection GetRefinerOrBaseConditioning()
{ {
return RefinerConditioning return RefinerConditioning ?? BaseConditioning ?? throw new NullReferenceException("No Conditioning");
?? BaseConditioning
?? throw new NullReferenceException("No Conditioning");
} }
public ConditioningNodeConnection GetRefinerOrBaseNegativeConditioning() public ConditioningNodeConnection GetRefinerOrBaseNegativeConditioning()
@ -892,10 +768,7 @@ public class ComfyNodeBuilder
public VAENodeConnection GetDefaultVAE() public VAENodeConnection GetDefaultVAE()
{ {
return PrimaryVAE return PrimaryVAE ?? RefinerVAE ?? BaseVAE ?? throw new NullReferenceException("No VAE");
?? RefinerVAE
?? BaseVAE
?? throw new NullReferenceException("No VAE");
} }
} }

Loading…
Cancel
Save