Browse Source

Add scheduler selection

pull/165/head
Ionite 1 year ago
parent
commit
4eb36cc3d0
No known key found for this signature in database
  1. 27
      StabilityMatrix.Avalonia/Controls/SamplerCard.axaml
  2. 35
      StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
  3. 10
      StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
  4. 27
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  5. 19
      StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs
  6. 45
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  7. 76
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
  8. 38
      StabilityMatrix.Core/Models/Api/Comfy/ComfyScheduler.cs
  9. 8
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

27
StabilityMatrix.Avalonia/Controls/SamplerCard.axaml

@ -21,7 +21,7 @@
<StackPanel
HorizontalAlignment="{TemplateBinding HorizontalAlignment}"
Spacing="8">
<Grid ColumnDefinitions="Auto,*" RowDefinitions="*,*,*">
<Grid ColumnDefinitions="Auto,*" RowDefinitions="*,*,*,*">
<!-- Sampler -->
<TextBlock
Grid.Row="0"
@ -39,15 +39,32 @@
DisplayMemberBinding="{Binding DisplayName}"
Margin="8,0,0,8"
HorizontalAlignment="Stretch"/>
<!-- Steps -->
<!-- Scheduler -->
<TextBlock
Grid.Row="1"
Grid.Column="0"
IsVisible="{Binding IsSchedulerSelectionEnabled}"
Margin="0,0,0,8"
VerticalAlignment="Center"
Text="Scheduler" />
<ui:FAComboBox
Grid.Row="1"
Grid.Column="1"
IsVisible="{Binding IsSchedulerSelectionEnabled}"
ItemsSource="{Binding ClientManager.Schedulers}"
SelectedItem="{Binding SelectedScheduler}"
DisplayMemberBinding="{Binding DisplayName}"
Margin="8,0,0,8"
HorizontalAlignment="Stretch"/>
<!-- Steps -->
<TextBlock
Grid.Row="2"
Grid.Column="0"
Margin="0,0,0,8"
VerticalAlignment="Center"
Text="Steps" />
<ui:NumberBox
Grid.Row="1"
Grid.Row="2"
Grid.Column="1"
SelectionHighlightColor="Transparent"
Value="{Binding Steps}"
@ -56,13 +73,13 @@
SpinButtonPlacementMode="Inline"/>
<!-- CFG Scale -->
<TextBlock
Grid.Row="2"
Grid.Row="3"
Grid.Column="0"
IsVisible="{Binding IsCfgScaleEnabled}"
VerticalAlignment="Center"
Text="CFG Scale" />
<ui:NumberBox
Grid.Row="2"
Grid.Row="3"
Grid.Column="1"
IsVisible="{Binding IsCfgScaleEnabled}"
SelectionHighlightColor="Transparent"

35
StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

@ -13,34 +13,31 @@ namespace StabilityMatrix.Avalonia.DesignData;
public class MockInferenceClientManager : ObservableObject, IInferenceClientManager
{
public ComfyClient? Client { get; set; }
public IObservableCollection<HybridModelFile> Models { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>(new ComfySampler[]
{
new("euler_ancestral"),
new("euler"),
new("lms"),
new("heun"),
new("dpm_2"),
new("dpm_2_ancestral")
});
new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults);
public IObservableCollection<ComfyUpscaler> Upscalers { get; } =
new ObservableCollectionExtended<ComfyUpscaler>(new ComfyUpscaler[]
{
new("nearest-exact", ComfyUpscalerType.Latent),
new("bicubic", ComfyUpscalerType.Latent),
new("ESRGAN-4x", ComfyUpscalerType.ESRGAN)
});
new ObservableCollectionExtended<ComfyUpscaler>(
new ComfyUpscaler[]
{
new("nearest-exact", ComfyUpscalerType.Latent),
new("bicubic", ComfyUpscalerType.Latent),
new("ESRGAN-4x", ComfyUpscalerType.ESRGAN)
}
);
public IObservableCollection<ComfyScheduler> Schedulers { get; } =
new ObservableCollectionExtended<ComfyScheduler>(ComfyScheduler.Defaults);
public bool IsConnected { get; set; }
public Task ConnectAsync()
{
return Task.CompletedTask;
@ -56,7 +53,7 @@ public class MockInferenceClientManager : ObservableObject, IInferenceClientMana
{
return Task.CompletedTask;
}
public void Dispose()
{
GC.SuppressFinalize(this);

10
StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

@ -10,10 +10,13 @@ using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Avalonia.Services;
public interface IInferenceClientManager : IDisposable, INotifyPropertyChanged, INotifyPropertyChanging
public interface IInferenceClientManager
: IDisposable,
INotifyPropertyChanged,
INotifyPropertyChanging
{
ComfyClient? Client { get; set; }
[MemberNotNullWhen(true, nameof(Client))]
bool IsConnected { get; }
@ -21,7 +24,8 @@ public interface IInferenceClientManager : IDisposable, INotifyPropertyChanged,
IObservableCollection<HybridModelFile> VaeModels { get; }
IObservableCollection<ComfySampler> Samplers { get; }
IObservableCollection<ComfyUpscaler> Upscalers { get; }
IObservableCollection<ComfyScheduler> Schedulers { get; }
Task ConnectAsync();
Task ConnectAsync(PackagePair packagePair);

27
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -64,6 +64,11 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
public IObservableCollection<ComfyUpscaler> Upscalers { get; } =
new ObservableCollectionExtended<ComfyUpscaler>();
private readonly SourceCache<ComfyScheduler, string> schedulersSource = new(p => p.Name);
public IObservableCollection<ComfyScheduler> Schedulers { get; } =
new ObservableCollectionExtended<ComfyScheduler>();
public InferenceClientManager(
ILogger<InferenceClientManager> logger,
IApiFactory apiFactory,
@ -77,7 +82,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
modelsSource.Connect().DeferUntilLoaded().Bind(Models).Subscribe();
vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);
vaeModelsDefaults
.Connect()
.Or(vaeModelsSource.Connect())
@ -94,6 +99,8 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
.Bind(Upscalers)
.Subscribe();
schedulersSource.Connect().DeferUntilLoaded().Bind(Schedulers).Subscribe();
ResetSharedProperties();
}
@ -141,10 +148,22 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
{ } modelUpscalerNames
)
{
modelUpscalersSource.EditDiff(modelUpscalerNames.Select(
s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)), ComfyUpscaler.Comparer);
modelUpscalersSource.EditDiff(
modelUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)),
ComfyUpscaler.Comparer
);
logger.LogTrace("Loaded model upscale methods: {@Upscalers}", modelUpscalerNames);
}
// Add scheduler names from Scheduler node
if (await Client.GetNodeOptionNamesAsync("KSampler", "scheduler") is { } schedulerNames)
{
schedulersSource.EditDiff(
schedulerNames.Select(s => new ComfyScheduler(s)),
ComfyScheduler.Comparer
);
logger.LogTrace("Loaded scheduler methods: {@Schedulers}", schedulerNames);
}
}
/// <summary>
@ -181,6 +200,8 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
latentUpscalersSource.EditDiff(ComfyUpscaler.Defaults, ComfyUpscaler.Comparer);
modelUpscalersSource.Clear();
schedulersSource.EditDiff(ComfyScheduler.Defaults, ComfyScheduler.Comparer);
}
public async Task ConnectAsync()

19
StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs

@ -1,9 +1,11 @@
using System;
using System.Collections.Specialized;
using System.IO;
using System.Threading.Tasks;
using Avalonia.Collections;
using Avalonia.Media;
using Avalonia.Media.Imaging;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
@ -65,6 +67,23 @@ public partial class ImageGalleryCardViewModel : ViewModelBase
ImageSources.CollectionChanged += OnImageSourcesItemsChanged;
}
public void SetPreviewImage(byte[] imageBytes)
{
Dispatcher.UIThread.Post(() =>
{
using var stream = new MemoryStream(imageBytes);
using var bitmap = new Bitmap(stream);
var currentImage = PreviewImage;
PreviewImage = bitmap;
IsPreviewOverlayEnabled = true;
currentImage?.Dispose();
});
}
private void OnImageSourcesItemsChanged(object? sender, NotifyCollectionChangedEventArgs e)
{
if (sender is AvaloniaList<ImageSource> sources)

45
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq;
using System.Text.Json.Serialization;
@ -95,7 +96,13 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Model Card
vmFactory.Get<ModelCardViewModel>(),
// Sampler
vmFactory.Get<SamplerCardViewModel>(),
vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{
samplerCard.IsDimensionsEnabled = true;
samplerCard.IsCfgScaleEnabled = true;
samplerCard.IsSamplerSelectionEnabled = true;
samplerCard.IsSchedulerSelectionEnabled = true;
}),
// Hires Fix
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
@ -108,9 +115,6 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Hires Fix Sampler
vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{
samplerCard.IsDimensionsEnabled = false;
samplerCard.IsCfgScaleEnabled = false;
samplerCard.IsSamplerSelectionEnabled = false;
samplerCard.IsDenoiseStrengthEnabled = true;
})
}
@ -257,9 +261,10 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
Convert.ToUInt64(seedCard.Seed),
samplerCard.Steps,
samplerCard.CfgScale,
samplerCard.SelectedSampler?.Name
?? throw new InvalidOperationException("Sampler not selected"),
"normal",
samplerCard.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
samplerCard.SelectedScheduler
?? throw new ValidationException("Sampler not selected"),
positiveClip.GetOutput<ConditioningNodeConnection>(0),
negativeClip.GetOutput<ConditioningNodeConnection>(0),
emptyLatentImage.GetOutput<LatentNodeConnection>(0),
@ -337,11 +342,13 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
Convert.ToUInt64(seedCard.Seed),
hiresSamplerCard.Steps,
hiresSamplerCard.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler name
hiresSamplerCard.SelectedSampler?.Name
?? samplerCard.SelectedSampler?.Name
?? throw new InvalidOperationException("Sampler not selected"),
"normal",
// Use hires sampler name if not null, otherwise use the normal sampler
hiresSamplerCard.SelectedSampler
?? samplerCard.SelectedSampler
?? throw new ValidationException("Sampler not selected"),
hiresSamplerCard.SelectedScheduler
?? samplerCard.SelectedScheduler
?? throw new ValidationException("Scheduler not selected"),
positiveClip.GetOutput<ConditioningNodeConnection>(0),
negativeClip.GetOutput<ConditioningNodeConnection>(0),
hiresLatent,
@ -405,19 +412,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
Dispatcher.UIThread.Post(() =>
{
using var stream = new MemoryStream(args.ImageBytes);
var bitmap = new Bitmap(stream);
var currentImage = ImageGalleryCardViewModel.PreviewImage;
ImageGalleryCardViewModel.PreviewImage = bitmap;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = true;
currentImage?.Dispose();
});
ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
}
private async Task GenerateImageImpl(

76
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -14,23 +14,42 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SamplerCard))]
public partial class SamplerCardViewModel : LoadableViewModelBase
{
[ObservableProperty] private int steps = 20;
[ObservableProperty]
private int steps = 20;
[ObservableProperty]
private bool isDenoiseStrengthEnabled;
[ObservableProperty] private bool isDenoiseStrengthEnabled;
[ObservableProperty] private double denoiseStrength = 1;
[ObservableProperty] private bool isCfgScaleEnabled = true;
[ObservableProperty] private double cfgScale = 7;
[ObservableProperty]
private double denoiseStrength = 1;
[ObservableProperty]
private bool isCfgScaleEnabled;
[ObservableProperty] private bool isDimensionsEnabled = true;
[ObservableProperty] private int width = 512;
[ObservableProperty] private int height = 512;
[ObservableProperty] private bool isSamplerSelectionEnabled = true;
[ObservableProperty]
private double cfgScale = 7;
[ObservableProperty]
private bool isDimensionsEnabled;
[ObservableProperty]
private int width = 512;
[ObservableProperty]
private int height = 512;
[ObservableProperty]
private bool isSamplerSelectionEnabled;
[ObservableProperty]
private ComfySampler? selectedSampler = new ComfySampler("euler_ancestral");
[ObservableProperty]
private bool isSchedulerSelectionEnabled;
[ObservableProperty]
private ComfyScheduler? selectedScheduler = new ComfyScheduler("normal");
public IInferenceClientManager ClientManager { get; }
public SamplerCardViewModel(IInferenceClientManager clientManager)
@ -42,7 +61,7 @@ public partial class SamplerCardViewModel : LoadableViewModelBase
public override void LoadStateFromJsonObject(JsonObject state)
{
var model = DeserializeModel<SamplerCardModel>(state);
Steps = model.Steps;
IsDenoiseStrengthEnabled = model.IsDenoiseStrengthEnabled;
DenoiseStrength = model.DenoiseStrength;
@ -52,25 +71,28 @@ public partial class SamplerCardViewModel : LoadableViewModelBase
Width = model.Width;
Height = model.Height;
IsSamplerSelectionEnabled = model.IsSamplerSelectionEnabled;
SelectedSampler = model.SelectedSampler is null ? null
SelectedSampler = model.SelectedSampler is null
? null
: new ComfySampler(model.SelectedSampler);
}
/// <inheritdoc />
public override JsonObject SaveStateToJsonObject()
{
return SerializeModel(new SamplerCardModel
{
Steps = Steps,
IsDenoiseStrengthEnabled = IsDenoiseStrengthEnabled,
DenoiseStrength = DenoiseStrength,
IsCfgScaleEnabled = IsCfgScaleEnabled,
CfgScale = CfgScale,
IsDimensionsEnabled = IsDimensionsEnabled,
Width = Width,
Height = Height,
IsSamplerSelectionEnabled = IsSamplerSelectionEnabled,
SelectedSampler = SelectedSampler?.Name
});
return SerializeModel(
new SamplerCardModel
{
Steps = Steps,
IsDenoiseStrengthEnabled = IsDenoiseStrengthEnabled,
DenoiseStrength = DenoiseStrength,
IsCfgScaleEnabled = IsCfgScaleEnabled,
CfgScale = CfgScale,
IsDimensionsEnabled = IsDimensionsEnabled,
Width = Width,
Height = Height,
IsSamplerSelectionEnabled = IsSamplerSelectionEnabled,
SelectedSampler = SelectedSampler?.Name
}
);
}
}

38
StabilityMatrix.Core/Models/Api/Comfy/ComfyScheduler.cs

@ -0,0 +1,38 @@
using System.Collections.Immutable;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public readonly record struct ComfyScheduler(string Name)
{
private static Dictionary<string, string> ConvertDict { get; } =
new()
{
["normal"] = "Normal",
["karras"] = "Karras",
["exponential"] = "Exponential",
["sgm_uniform"] = "SGM Uniform",
["simple"] = "Simple",
["ddim_uniform"] = "DDIM Uniform"
};
public static IReadOnlyList<ComfyScheduler> Defaults { get; } =
ConvertDict.Keys.Select(k => new ComfyScheduler(k)).ToImmutableArray();
public string DisplayName =>
ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name;
private sealed class NameEqualityComparer : IEqualityComparer<ComfyScheduler>
{
public bool Equals(ComfyScheduler x, ComfyScheduler y)
{
return x.Name == y.Name;
}
public int GetHashCode(ComfyScheduler obj)
{
return obj.Name.GetHashCode();
}
}
public static IEqualityComparer<ComfyScheduler> Comparer { get; } = new NameEqualityComparer();
}

8
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -60,8 +60,8 @@ public class ComfyNodeBuilder
ulong seed,
int steps,
double cfg,
string samplerName,
string scheduler,
ComfySampler sampler,
ComfyScheduler scheduler,
ConditioningNodeConnection positive,
ConditioningNodeConnection negative,
LatentNodeConnection latentImage,
@ -77,8 +77,8 @@ public class ComfyNodeBuilder
["seed"] = seed,
["steps"] = steps,
["cfg"] = cfg,
["sampler_name"] = samplerName,
["scheduler"] = scheduler,
["sampler_name"] = sampler.Name,
["scheduler"] = scheduler.Name,
["positive"] = positive.Data,
["negative"] = negative.Data,
["latent_image"] = latentImage.Data,

Loading…
Cancel
Save