Browse Source

Abstract common generation functionality to InferenceGenerationViewModelBase

pull/165/head
Ionite 1 year ago
parent
commit
11cb46abf8
No known key found for this signature in database
  1. 3
      StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs
  2. 6
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  3. 24
      StabilityMatrix.Avalonia/Models/InferenceProjectType.cs
  4. 93
      StabilityMatrix.Avalonia/Services/ServiceManager.cs
  5. 354
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  6. 284
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  7. 19
      StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
  8. 13
      StabilityMatrix.Avalonia/Views/InferencePage.axaml
  9. 10
      StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs
  10. 5
      StabilityMatrix.Core/Helper/EventManager.cs
  11. 5
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  12. 2
      StabilityMatrix.Core/Services/IImageIndexService.cs
  13. 10
      StabilityMatrix.Core/Services/ImageIndexService.cs

3
StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs

@ -54,9 +54,6 @@ public class MockImageIndexService : IImageIndexService
return Task.CompletedTask; return Task.CompletedTask;
} }
/// <inheritdoc />
public void OnImageAdded(FilePath filePath) { }
/// <inheritdoc /> /// <inheritdoc />
public void BackgroundRefreshIndex() public void BackgroundRefreshIndex()
{ {

6
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -253,11 +253,13 @@ public static class ComfyNodeBuilderExtensions
var vaeDecoder = builder.Nodes.AddNamedNode( var vaeDecoder = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.VAEDecode( ComfyNodeBuilder.VAEDecode(
"VAEDecode", "VAEDecode",
builder.Connections.Latent!, builder.Connections.Latent
?? throw new InvalidOperationException("Latent source not set"),
builder.Connections.GetRefinerOrBaseVAE() builder.Connections.GetRefinerOrBaseVAE()
) )
); );
builder.Connections.Image = vaeDecoder.Output; builder.Connections.Image = vaeDecoder.Output;
builder.Connections.ImageSize = builder.Connections.LatentSize;
} }
var saveImage = builder.Nodes.AddNamedNode( var saveImage = builder.Nodes.AddNamedNode(
@ -272,6 +274,8 @@ public static class ComfyNodeBuilderExtensions
} }
); );
builder.Connections.OutputNodes.Add(saveImage);
return saveImage.Name; return saveImage.Name;
} }
} }

24
StabilityMatrix.Avalonia/Models/InferenceProjectType.cs

@ -1,7 +1,29 @@
namespace StabilityMatrix.Avalonia.Models; using System;
using StabilityMatrix.Avalonia.ViewModels.Inference;
namespace StabilityMatrix.Avalonia.Models;
public enum InferenceProjectType public enum InferenceProjectType
{ {
Unknown, Unknown,
TextToImage, TextToImage,
ImageToImage,
Inpainting,
Upscale
}
public static class InferenceProjectTypeExtensions
{
public static Type? ToViewModelType(this InferenceProjectType type)
{
return type switch
{
InferenceProjectType.TextToImage => typeof(InferenceTextToImageViewModel),
InferenceProjectType.ImageToImage => null,
InferenceProjectType.Inpainting => null,
InferenceProjectType.Upscale => typeof(InferenceImageUpscaleViewModel),
InferenceProjectType.Unknown => null,
_ => throw new ArgumentOutOfRangeException(nameof(type), type, null)
};
}
} }

93
StabilityMatrix.Avalonia/Services/ServiceManager.cs

@ -20,16 +20,19 @@ public class ServiceManager<T>
/// <summary> /// <summary>
/// Register a new dialog view model (singleton instance) /// Register a new dialog view model (singleton instance)
/// </summary> /// </summary>
public ServiceManager<T> Register<TService>(TService instance) where TService : T public ServiceManager<T> Register<TService>(TService instance)
where TService : T
{ {
if (instance is null) throw new ArgumentNullException(nameof(instance)); if (instance is null)
throw new ArgumentNullException(nameof(instance));
lock (instances) lock (instances)
{ {
if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService))) if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService)))
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {typeof(TService)} is already registered for {typeof(T)}"); $"Service of type {typeof(TService)} is already registered for {typeof(T)}"
);
} }
instances[instance.GetType()] = instance; instances[instance.GetType()] = instance;
@ -41,14 +44,16 @@ public class ServiceManager<T>
/// <summary> /// <summary>
/// Register a new dialog view model provider action (called on each dialog creation) /// Register a new dialog view model provider action (called on each dialog creation)
/// </summary> /// </summary>
public ServiceManager<T> Register<TService>(Func<TService> provider) where TService : T public ServiceManager<T> Register<TService>(Func<TService> provider)
where TService : T
{ {
lock (providers) lock (providers)
{ {
if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService))) if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService)))
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {typeof(TService)} is already registered for {typeof(T)}"); $"Service of type {typeof(TService)} is already registered for {typeof(T)}"
);
} }
// Return type is wrong during build with method group syntax // Return type is wrong during build with method group syntax
@ -63,14 +68,16 @@ public class ServiceManager<T>
/// Register a new dialog view model instance using a service provider /// Register a new dialog view model instance using a service provider
/// Equal to Register[TService](serviceProvider.GetRequiredService[TService]) /// Equal to Register[TService](serviceProvider.GetRequiredService[TService])
/// </summary> /// </summary>
public ServiceManager<T> RegisterProvider<TService>(IServiceProvider provider) where TService : notnull, T public ServiceManager<T> RegisterProvider<TService>(IServiceProvider provider)
where TService : notnull, T
{ {
lock (providers) lock (providers)
{ {
if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService))) if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService)))
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {typeof(TService)} is already registered for {typeof(T)}"); $"Service of type {typeof(TService)} is already registered for {typeof(T)}"
);
} }
// Return type is wrong during build with method group syntax // Return type is wrong during build with method group syntax
@ -87,10 +94,11 @@ public class ServiceManager<T>
[SuppressMessage("ReSharper", "InconsistentlySynchronizedField")] [SuppressMessage("ReSharper", "InconsistentlySynchronizedField")]
public T Get(Type serviceType) public T Get(Type serviceType)
{ {
if (!serviceType.IsAssignableFrom(typeof(T))) if (!serviceType.IsAssignableTo(typeof(T)))
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service type {serviceType} is not assignable from {typeof(T)}"); $"Service type {serviceType} is not assignable to {typeof(T)}"
);
} }
if (instances.TryGetValue(serviceType, out var instance)) if (instances.TryGetValue(serviceType, out var instance))
@ -98,7 +106,8 @@ public class ServiceManager<T>
if (instance is null) if (instance is null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {serviceType} was registered as null"); $"Service of type {serviceType} was registered as null"
);
} }
return (T)instance; return (T)instance;
} }
@ -108,33 +117,38 @@ public class ServiceManager<T>
if (provider is null) if (provider is null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {serviceType} was registered as null"); $"Service of type {serviceType} was registered as null"
);
} }
var result = provider(); var result = provider();
if (result is null) if (result is null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service provider for type {serviceType} returned null"); $"Service provider for type {serviceType} returned null"
);
} }
return (T)result; return (T)result;
} }
throw new ArgumentException( throw new ArgumentException(
$"Service of type {serviceType} is not registered for {typeof(T)}"); $"Service of type {serviceType} is not registered for {typeof(T)}"
);
} }
/// <summary> /// <summary>
/// Get a view model instance /// Get a view model instance
/// </summary> /// </summary>
[SuppressMessage("ReSharper", "InconsistentlySynchronizedField")] [SuppressMessage("ReSharper", "InconsistentlySynchronizedField")]
public TService Get<TService>() where TService : T public TService Get<TService>()
where TService : T
{ {
if (instances.TryGetValue(typeof(TService), out var instance)) if (instances.TryGetValue(typeof(TService), out var instance))
{ {
if (instance is null) if (instance is null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {typeof(TService)} was registered as null"); $"Service of type {typeof(TService)} was registered as null"
);
} }
return (TService)instance; return (TService)instance;
} }
@ -144,25 +158,29 @@ public class ServiceManager<T>
if (provider is null) if (provider is null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service of type {typeof(TService)} was registered as null"); $"Service of type {typeof(TService)} was registered as null"
);
} }
var result = provider(); var result = provider();
if (result is null) if (result is null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Service provider for type {typeof(TService)} returned null"); $"Service provider for type {typeof(TService)} returned null"
);
} }
return (TService)result; return (TService)result;
} }
throw new ArgumentException( throw new ArgumentException(
$"Service of type {typeof(TService)} is not registered for {typeof(T)}"); $"Service of type {typeof(TService)} is not registered for {typeof(T)}"
);
} }
/// <summary> /// <summary>
/// Get a view model instance with an initializer parameter /// Get a view model instance with an initializer parameter
/// </summary> /// </summary>
public TService Get<TService>(Func<TService, TService> initializer) where TService : T public TService Get<TService>(Func<TService, TService> initializer)
where TService : T
{ {
var instance = Get<TService>(); var instance = Get<TService>();
return initializer(instance); return initializer(instance);
@ -171,7 +189,8 @@ public class ServiceManager<T>
/// <summary> /// <summary>
/// Get a view model instance with an initializer for a mutable instance /// Get a view model instance with an initializer for a mutable instance
/// </summary> /// </summary>
public TService Get<TService>(Action<TService> initializer) where TService : T public TService Get<TService>(Action<TService> initializer)
where TService : T
{ {
var instance = Get<TService>(); var instance = Get<TService>();
initializer(instance); initializer(instance);
@ -182,19 +201,26 @@ public class ServiceManager<T>
/// Get a view model instance, set as DataContext of its View, and return /// Get a view model instance, set as DataContext of its View, and return
/// a BetterContentDialog with that View as its Content /// a BetterContentDialog with that View as its Content
/// </summary> /// </summary>
public BetterContentDialog GetDialog<TService>() where TService : T public BetterContentDialog GetDialog<TService>()
where TService : T
{ {
var instance = Get<TService>()!; var instance = Get<TService>()!;
if (Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute)) is not ViewAttribute if (
viewAttr) Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute))
is not ViewAttribute viewAttr
)
{ {
throw new InvalidOperationException($"View not found for {instance.GetType().FullName}"); throw new InvalidOperationException(
$"View not found for {instance.GetType().FullName}"
);
} }
if (Activator.CreateInstance(viewAttr.ViewType) is not Control view) if (Activator.CreateInstance(viewAttr.ViewType) is not Control view)
{ {
throw new NullReferenceException($"Unable to create instance for {instance.GetType().FullName}"); throw new NullReferenceException(
$"Unable to create instance for {instance.GetType().FullName}"
);
} }
return new BetterContentDialog { Content = view }; return new BetterContentDialog { Content = view };
@ -204,19 +230,26 @@ public class ServiceManager<T>
/// Get a view model instance with initializer, set as DataContext of its View, and return /// Get a view model instance with initializer, set as DataContext of its View, and return
/// a BetterContentDialog with that View as its Content /// a BetterContentDialog with that View as its Content
/// </summary> /// </summary>
public BetterContentDialog GetDialog<TService>(Action<TService> initializer) where TService : T public BetterContentDialog GetDialog<TService>(Action<TService> initializer)
where TService : T
{ {
var instance = Get(initializer)!; var instance = Get(initializer)!;
if (Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute)) is not ViewAttribute if (
viewAttr) Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute))
is not ViewAttribute viewAttr
)
{ {
throw new InvalidOperationException($"View not found for {instance.GetType().FullName}"); throw new InvalidOperationException(
$"View not found for {instance.GetType().FullName}"
);
} }
if (Activator.CreateInstance(viewAttr.ViewType) is not Control view) if (Activator.CreateInstance(viewAttr.ViewType) is not Control view)
{ {
throw new NullReferenceException($"Unable to create instance for {instance.GetType().FullName}"); throw new NullReferenceException(
$"Unable to create instance for {instance.GetType().FullName}"
);
} }
view.DataContext = instance; view.DataContext = instance;

354
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -0,0 +1,354 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
/// <summary>
/// Abstract base class for tab view models that generate images using ClientManager.
/// This includes a progress reporter, image output view model, and generation virtual methods.
/// </summary>
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract partial class InferenceGenerationViewModelBase
: InferenceTabViewModelBase,
IImageGalleryComponent
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService;
[JsonPropertyName("ImageGallery")]
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
[JsonIgnore]
public ImageFolderCardViewModel ImageFolderCardViewModel { get; }
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
/// <inheritdoc />
protected InferenceGenerationViewModelBase(
ServiceManager<ViewModelBase> vmFactory,
IInferenceClientManager inferenceClientManager,
INotificationService notificationService
)
{
this.notificationService = notificationService;
ClientManager = inferenceClientManager;
ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
ImageFolderCardViewModel = vmFactory.Get<ImageFolderCardViewModel>();
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
}
/// <summary>
/// Builds the image generation prompt
/// </summary>
protected virtual void BuildPrompt(BuildPromptEventArgs args) { }
/// <summary>
/// Runs a generation task
/// </summary>
/// <exception cref="InvalidOperationException">Thrown if args.Parameters or args.Project are null</exception>
protected async Task RunGeneration(
ImageGenerationEventArgs args,
CancellationToken cancellationToken
)
{
var client = args.Client;
var nodes = args.Nodes;
// Checks
if (args.Parameters is null)
throw new InvalidOperationException("Parameters is null");
if (args.Project is null)
throw new InvalidOperationException("Project is null");
if (args.OutputNodeNames.Count == 0)
throw new InvalidOperationException("OutputNodeNames is empty");
if (client.OutputImagesDir is null)
throw new InvalidOperationException("OutputImagesDir is null");
// Connect preview image handler
client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try
{
// Register to interrupt if user cancels
cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client
.InterruptPromptAsync(new CancellationTokenSource(5000).Token)
.SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
}
catch (ApiException e)
{
Logger.Warn(e, "Api exception while queuing prompt");
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
// Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished");
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
promptTask.Id,
cancellationToken
);
ImageGalleryCardViewModel.ImageSources.Clear();
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
await ProcessOutputImages(images, args);
}
finally
{
// Disconnect progress handler
client.PreviewImageReceived -= OnPreviewImageReceived;
// Clear progress
OutputProgress.Value = 0;
OutputProgress.Text = "";
ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
// Cleanup tasks
promptTask?.Dispose();
}
}
/// <summary>
/// Handles image output metadata for generation runs
/// </summary>
private async Task ProcessOutputImages(
IEnumerable<ComfyImage> images,
ImageGenerationEventArgs args
)
{
// Write metadata to images
var outputImages = new List<ImageSource>();
foreach (
var filePath in images.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
)
{
var bytesWithMetadata = PngDataHelper.AddMetadata(
await filePath.ReadAllBytesAsync(),
args.Parameters!,
args.Project!
);
await using (var outputStream = filePath.Info.OpenWrite())
{
await outputStream.WriteAsync(bytesWithMetadata);
await outputStream.FlushAsync();
}
outputImages.Add(new ImageSource(filePath));
EventManager.Instance.OnImageFileAdded(filePath);
}
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
var loadedImages = outputImages
.Select(i => SKImage.FromEncodedData(i.LocalFile?.Info.OpenRead()))
.ToImmutableArray();
var grid = ImageProcessor.CreateImageGrid(loadedImages);
var gridBytes = grid.Encode().ToArray();
var gridBytesWithMetadata = PngDataHelper.AddMetadata(
gridBytes,
args.Parameters!,
args.Project!
);
// Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = args.Client.OutputImagesDir!.JoinFile($"grid-{lastName}");
await using (var fileStream = gridPath.Info.OpenWrite())
{
await fileStream.WriteAsync(gridBytesWithMetadata);
}
// Insert to start of images
var gridImage = new ImageSource(gridPath);
// Preload
await gridImage.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(gridImage);
EventManager.Instance.OnImageFileAdded(gridPath);
}
// Add rest of images
foreach (var img in outputImages)
{
// Preload
await img.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(img);
}
}
/// <summary>
/// Implementation for Generate Image
/// </summary>
protected virtual Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
{
return Task.CompletedTask;
}
/// <summary>
/// Command for the Generate Image button
/// </summary>
/// <param name="options">Optional overrides (side buttons)</param>
/// <param name="cancellationToken">Cancellation token</param>
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
GenerateFlags options = default,
CancellationToken cancellationToken = default
)
{
try
{
var overrides = GenerateOverrides.FromFlags(options);
await GenerateImageImpl(overrides, cancellationToken);
}
catch (OperationCanceledException)
{
Logger.Debug($"Image Generation Canceled");
}
}
/// <summary>
/// Handles the preview image received event from the websocket.
/// Updates the preview image in the image gallery.
/// </summary>
protected virtual void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
}
/// <summary>
/// Handles the progress update received event from the websocket.
/// Updates the progress view model.
/// </summary>
protected virtual void OnProgressUpdateReceived(
object? sender,
ComfyProgressUpdateEventArgs args
)
{
Dispatcher.UIThread.Post(() =>
{
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text =
$"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
});
}
public class ImageGenerationEventArgs : EventArgs
{
public required ComfyClient Client { get; init; }
public required NodeDictionary Nodes { get; init; }
public required IReadOnlyList<string> OutputNodeNames { get; init; }
public GenerationParameters? Parameters { get; set; }
public InferenceProjectDocument? Project { get; set; }
}
public class BuildPromptEventArgs : EventArgs
{
public ComfyNodeBuilder Builder { get; } = new();
public GenerateOverrides Overrides { get; set; } = new();
}
[Flags]
public enum GenerateFlags
{
None = 0,
HiresFixEnable = 1 << 1,
HiresFixDisable = 1 << 2,
UseCurrentSeed = 1 << 3,
UseRandomSeed = 1 << 4
}
public class GenerateOverrides
{
public bool? IsHiresFixEnabled { get; set; }
public bool? UseCurrentSeed { get; set; }
public static GenerateOverrides FromFlags(GenerateFlags flags)
{
var overrides = new GenerateOverrides()
{
IsHiresFixEnabled = flags.HasFlag(GenerateFlags.HiresFixEnable)
? true
: flags.HasFlag(GenerateFlags.HiresFixDisable)
? false
: null,
UseCurrentSeed = flags.HasFlag(GenerateFlags.UseCurrentSeed)
? true
: flags.HasFlag(GenerateFlags.UseRandomSeed)
? false
: null
};
return overrides;
}
}
}

284
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -1,35 +1,21 @@
using System; using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq; using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using DynamicData.Binding; using DynamicData.Binding;
using NLog; using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView; using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView;
@ -38,19 +24,12 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), persistent: true)] [View(typeof(InferenceTextToImageView), persistent: true)]
public partial class InferenceTextToImageViewModel public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase
: InferenceTabViewModelBase,
IImageGalleryComponent
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private readonly ServiceManager<ViewModelBase> vmFactory;
private readonly IModelIndexService modelIndexService; private readonly IModelIndexService modelIndexService;
private readonly IImageIndexService imageIndexService;
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
[JsonIgnore] [JsonIgnore]
public StackCardViewModel StackCardViewModel { get; } public StackCardViewModel StackCardViewModel { get; }
@ -61,12 +40,6 @@ public partial class InferenceTextToImageViewModel
[JsonPropertyName("Sampler")] [JsonPropertyName("Sampler")]
public SamplerCardViewModel SamplerCardViewModel { get; } public SamplerCardViewModel SamplerCardViewModel { get; }
[JsonPropertyName("ImageGallery")]
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
[JsonPropertyName("ImageFolder")]
public ImageFolderCardViewModel ImageFolderCardViewModel { get; }
[JsonPropertyName("Prompt")] [JsonPropertyName("Prompt")]
public PromptCardViewModel PromptCardViewModel { get; } public PromptCardViewModel PromptCardViewModel { get; }
@ -97,26 +70,16 @@ public partial class InferenceTextToImageViewModel
set => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled = value; set => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled = value;
} }
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
[ObservableProperty]
[property: JsonIgnore]
private string? outputImageSource;
public InferenceTextToImageViewModel( public InferenceTextToImageViewModel(
INotificationService notificationService, INotificationService notificationService,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
ServiceManager<ViewModelBase> vmFactory, ServiceManager<ViewModelBase> vmFactory,
IModelIndexService modelIndexService, IModelIndexService modelIndexService
IImageIndexService imageIndexService
) )
: base(vmFactory, inferenceClientManager, notificationService)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;
this.vmFactory = vmFactory;
this.modelIndexService = modelIndexService; this.modelIndexService = modelIndexService;
this.imageIndexService = imageIndexService;
ClientManager = inferenceClientManager;
// Get sub view models from service manager // Get sub view models from service manager
@ -133,8 +96,6 @@ public partial class InferenceTextToImageViewModel
samplerCard.IsSchedulerSelectionEnabled = true; samplerCard.IsSchedulerSelectionEnabled = true;
}); });
ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
ImageFolderCardViewModel = vmFactory.Get<ImageFolderCardViewModel>();
PromptCardViewModel = vmFactory.Get<PromptCardViewModel>(); PromptCardViewModel = vmFactory.Get<PromptCardViewModel>();
HiresSamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard => HiresSamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{ {
@ -181,17 +142,16 @@ public partial class InferenceTextToImageViewModel
SamplerCardViewModel.IsRefinerStepsEnabled = SamplerCardViewModel.IsRefinerStepsEnabled =
e.Sender is { IsRefinerSelectionEnabled: true, SelectedRefiner: not null }; e.Sender is { IsRefinerSelectionEnabled: true, SelectedRefiner: not null };
}); });
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
} }
private (NodeDictionary prompt, string[] outputs) BuildPrompt( /// <inheritdoc />
GenerateOverrides? overrides = null protected override void BuildPrompt(BuildPromptEventArgs args)
)
{ {
using var _ = new CodeTimer(); base.BuildPrompt(args);
var builder = new ComfyNodeBuilder(); using var _ = CodeTimer.StartDebug();
var builder = args.Builder;
var nodes = builder.Nodes; var nodes = builder.Nodes;
// Setup empty latent // Setup empty latent
@ -232,7 +192,7 @@ public partial class InferenceTextToImageViewModel
} }
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node // If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
if (overrides?.IsHiresFixEnabled ?? IsHiresFixEnabled) if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
{ {
// Requested upscale to this size // Requested upscale to this size
var hiresSize = builder.Connections.GetScaledLatentSize( var hiresSize = builder.Connections.GetScaledLatentSize(
@ -309,35 +269,12 @@ public partial class InferenceTextToImageViewModel
// Set as the image output // Set as the image output
builder.Connections.Image = postUpscaleGroup.Output; builder.Connections.Image = postUpscaleGroup.Output;
} }
// Output
var outputName = builder.SetupOutputImage();
return (builder.ToNodeDictionary(), new[] { outputName });
} }
private void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args) /// <inheritdoc />
{ protected override async Task GenerateImageImpl(
Dispatcher.UIThread.Post(() => GenerateOverrides overrides,
{ CancellationToken cancellationToken
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text =
$"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
});
}
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
}
private async Task GenerateImageImpl(
GenerateOverrides? overrides = null,
CancellationToken cancellationToken = default
) )
{ {
// Validate the prompts // Validate the prompts
@ -359,192 +296,33 @@ public partial class InferenceTextToImageViewModel
seedCard.GenerateNewSeed(); seedCard.GenerateNewSeed();
} }
var client = ClientManager.Client; var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides };
BuildPrompt(buildPromptArgs);
var (nodes, outputNodeNames) = BuildPrompt(overrides); var generationArgs = new ImageGenerationEventArgs
{
var generationInfo = new GenerationParameters Client = ClientManager.Client,
Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
Parameters = new GenerationParameters
{ {
Seed = (ulong)seedCard.Seed, Seed = (ulong)seedCard.Seed,
Steps = SamplerCardViewModel.Steps, Steps = SamplerCardViewModel.Steps,
CfgScale = SamplerCardViewModel.CfgScale, CfgScale = SamplerCardViewModel.CfgScale,
Sampler = SamplerCardViewModel.SelectedSampler?.Name, Sampler = SamplerCardViewModel.SelectedSampler?.Name,
ModelName = ModelCardViewModel.SelectedModelName, ModelName = ModelCardViewModel.SelectedModelName,
// TODO: ModelHash ModelHash = ModelCardViewModel
.SelectedModel
?.Local
?.ConnectedModelInfo
?.Hashes
.SHA256,
PositivePrompt = PromptCardViewModel.PromptDocument.Text, PositivePrompt = PromptCardViewModel.PromptDocument.Text,
NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text
}; },
var smproj = InferenceProjectDocument.FromLoadable(this); Project = InferenceProjectDocument.FromLoadable(this)
// Connect preview image handler
client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try
{
// Register to interrupt if user cancels
cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client
.InterruptPromptAsync(new CancellationTokenSource(5000).Token)
.SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
}
catch (ApiException e)
{
Logger.Warn(e, "Api exception while queuing prompt");
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
// Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished");
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
promptTask.Id,
cancellationToken
);
ImageGalleryCardViewModel.ImageSources.Clear();
if (!imageOutputs.TryGetValue(outputNodeNames[0], out var images) || images is null)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
List<ImageSource> outputImages;
// Use local file path if available, otherwise use remote URL
if (client.OutputImagesDir is { } outputPath)
{
outputImages = new List<ImageSource>();
foreach (var image in images)
{
var filePath = image.ToFilePath(outputPath);
var bytesWithMetadata = PngDataHelper.AddMetadata(
await filePath.ReadAllBytesAsync(),
generationInfo,
smproj
);
/*await using (var readStream = filePath.Info.OpenWrite())
{
using (var reader = new BinaryReader(readStream))
{
}
}*/
await using (var outputStream = filePath.Info.OpenWrite())
{
await outputStream.WriteAsync(bytesWithMetadata);
await outputStream.FlushAsync();
}
outputImages.Add(new ImageSource(filePath));
imageIndexService.OnImageAdded(filePath);
}
}
else
{
outputImages = images!
.Select(i => new ImageSource(i.ToUri(client.BaseAddress)))
.ToList();
}
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
var loadedImages = outputImages
.Select(i => SKImage.FromEncodedData(i.LocalFile?.Info.OpenRead()))
.ToImmutableArray();
var grid = ImageProcessor.CreateImageGrid(loadedImages);
var gridBytes = grid.Encode().ToArray();
var gridBytesWithMetadata = PngDataHelper.AddMetadata(
gridBytes,
generationInfo,
smproj
);
// Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = client.OutputImagesDir!.JoinFile($"grid-{lastName}");
await using (var fileStream = gridPath.Info.OpenWrite())
{
await fileStream.WriteAsync(gridBytesWithMetadata, cancellationToken);
}
// Insert to start of images
var gridImage = new ImageSource(gridPath);
// Preload
await gridImage.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(gridImage);
imageIndexService.OnImageAdded(gridPath);
}
// Add rest of images
foreach (var img in outputImages)
{
// Preload
await img.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(img);
}
}
finally
{
// Disconnect progress handler
OutputProgress.Value = 0;
OutputProgress.Text = "";
ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
promptTask?.Dispose();
client.PreviewImageReceived -= OnPreviewImageReceived;
}
}
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
string? options = null,
CancellationToken cancellationToken = default
)
{
try
{
var overrides = new GenerateOverrides
{
IsHiresFixEnabled = options?.Contains("hires_fix"),
UseCurrentSeed = options?.Contains("current_seed")
}; };
await GenerateImageImpl(overrides, cancellationToken); await RunGeneration(generationArgs, cancellationToken);
}
catch (OperationCanceledException e)
{
Logger.Debug($"[Image Generation Canceled] {e.Message}");
}
}
internal class GenerateOverrides
{
public bool? IsHiresFixEnabled { get; set; }
public bool? UseCurrentSeed { get; set; }
} }
} }

19
StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs

@ -70,6 +70,8 @@ public partial class InferenceViewModel : PageViewModelBase
public IInferenceClientManager ClientManager { get; } public IInferenceClientManager ClientManager { get; }
public SharedState SharedState { get; }
public ObservableCollection<InferenceTabViewModelBase> Tabs { get; } = new(); public ObservableCollection<InferenceTabViewModelBase> Tabs { get; } = new();
[ObservableProperty] [ObservableProperty]
@ -93,7 +95,8 @@ public partial class InferenceViewModel : PageViewModelBase
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IModelIndexService modelIndexService, IModelIndexService modelIndexService,
ILiteDbContext liteDbContext ILiteDbContext liteDbContext,
SharedState sharedState
) )
{ {
this.vmFactory = vmFactory; this.vmFactory = vmFactory;
@ -103,6 +106,7 @@ public partial class InferenceViewModel : PageViewModelBase
this.liteDbContext = liteDbContext; this.liteDbContext = liteDbContext;
ClientManager = inferenceClientManager; ClientManager = inferenceClientManager;
SharedState = sharedState;
// Keep RunningPackage updated with the current package pair // Keep RunningPackage updated with the current package pair
EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged; EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged;
@ -240,8 +244,10 @@ public partial class InferenceViewModel : PageViewModelBase
continue; continue;
} }
var projectPath = projectFile.ToString();
var entry = await liteDbContext.InferenceProjects.FindOneAsync( var entry = await liteDbContext.InferenceProjects.FindOneAsync(
p => p.FilePath == projectFile.ToString() p => p.FilePath == projectPath
); );
// Create if not found // Create if not found
@ -299,9 +305,14 @@ public partial class InferenceViewModel : PageViewModelBase
/// When the + button on the tab control is clicked, add a new tab. /// When the + button on the tab control is clicked, add a new tab.
/// </summary> /// </summary>
[RelayCommand] [RelayCommand]
private void AddTab() public void AddTab(InferenceProjectType type = InferenceProjectType.TextToImage)
{
if (type.ToViewModelType() is not { } vmType)
{ {
var tab = vmFactory.Get<InferenceTextToImageViewModel>(); return;
}
var tab = (InferenceTabViewModelBase)vmFactory.Get(vmType);
Tabs.Add(tab); Tabs.Add(tab);
// Set as new selected tab // Set as new selected tab

13
StabilityMatrix.Avalonia/Views/InferencePage.axaml

@ -40,7 +40,7 @@
<ui:CommandBarButton Label="Delete" ToolTip.Tip="Delete" />--> <ui:CommandBarButton Label="Delete" ToolTip.Tip="Delete" />-->
<ui:CommandBarFlyout.SecondaryCommands> <ui:CommandBarFlyout.SecondaryCommands>
<ui:CommandBarButton <ui:CommandBarButton
Click="AddTabMenu_TextToImageButton_OnClick" Click="AddTabMenu_TextToImage_OnClick"
IconSource="FullScreenMaximize" IconSource="FullScreenMaximize"
Label="Text to Image" Label="Text to Image"
ToolTip.Tip="Text to Image" /> ToolTip.Tip="Text to Image" />
@ -54,6 +54,17 @@
IsEnabled="False" IsEnabled="False"
Label="Inpaint" Label="Inpaint"
ToolTip.Tip="Inpaint" /> ToolTip.Tip="Inpaint" />
<ui:CommandBarButton
Click="AddTabMenu_Upscale_OnClick"
IsEnabled="{Binding SharedState.IsDebugMode}"
Label="Upscale"
ToolTip.Tip="Upscale">
<ui:CommandBarButton.IconSource>
<fluentIcons:SymbolIconSource
FontSize="10"
Symbol="ResizeImage"/>
</ui:CommandBarButton.IconSource>
</ui:CommandBarButton>
</ui:CommandBarFlyout.SecondaryCommands> </ui:CommandBarFlyout.SecondaryCommands>
</ui:CommandBarFlyout> </ui:CommandBarFlyout>
</controls:UserControlBase.Resources> </controls:UserControlBase.Resources>

10
StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs

@ -6,6 +6,7 @@ using Avalonia.Input;
using Avalonia.Interactivity; using Avalonia.Interactivity;
using FluentAvalonia.UI.Controls; using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.ViewModels; using StabilityMatrix.Avalonia.ViewModels;
namespace StabilityMatrix.Avalonia.Views; namespace StabilityMatrix.Avalonia.Views;
@ -47,8 +48,13 @@ public partial class InferencePage : UserControlBase
addTabFlyout.ShowAt(AddButton); addTabFlyout.ShowAt(AddButton);
} }
private void AddTabMenu_TextToImageButton_OnClick(object? sender, RoutedEventArgs e) private void AddTabMenu_TextToImage_OnClick(object? sender, RoutedEventArgs e)
{ {
(DataContext as InferenceViewModel)!.AddTabCommand.Execute(null); (DataContext as InferenceViewModel)!.AddTab();
}
private void AddTabMenu_Upscale_OnClick(object? sender, RoutedEventArgs e)
{
(DataContext as InferenceViewModel)!.AddTab(InferenceProjectType.Upscale);
} }
} }

5
StabilityMatrix.Core/Helper/EventManager.cs

@ -1,5 +1,6 @@
using System.Globalization; using System.Globalization;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification; using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Models.Update; using StabilityMatrix.Core.Models.Update;
@ -32,6 +33,8 @@ public class EventManager
public event EventHandler? ModelIndexChanged; public event EventHandler? ModelIndexChanged;
public event EventHandler<FilePath>? ImageFileAdded;
public void OnGlobalProgressChanged(int progress) => public void OnGlobalProgressChanged(int progress) =>
GlobalProgressChanged?.Invoke(this, progress); GlobalProgressChanged?.Invoke(this, progress);
@ -69,4 +72,6 @@ public class EventManager
public void OnCultureChanged(CultureInfo culture) => CultureChanged?.Invoke(this, culture); public void OnCultureChanged(CultureInfo culture) => CultureChanged?.Invoke(this, culture);
public void OnModelIndexChanged() => ModelIndexChanged?.Invoke(this, EventArgs.Empty); public void OnModelIndexChanged() => ModelIndexChanged?.Invoke(this, EventArgs.Empty);
public void OnImageFileAdded(FilePath filePath) => ImageFileAdded?.Invoke(this, filePath);
} }

5
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -529,6 +529,11 @@ public class ComfyNodeBuilder
public Size LatentSize { get; set; } public Size LatentSize { get; set; }
public ImageNodeConnection? Image { get; set; } public ImageNodeConnection? Image { get; set; }
public Size ImageSize { get; set; }
public List<NamedComfyNode> OutputNodes { get; } = new();
public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name);
/// <summary> /// <summary>
/// Gets the latent size scaled by a given factor /// Gets the latent size scaled by a given factor

2
StabilityMatrix.Core/Services/IImageIndexService.cs

@ -21,8 +21,6 @@ public interface IImageIndexService
Task RefreshIndex(IndexCollection<LocalImageFile, string> indexCollection); Task RefreshIndex(IndexCollection<LocalImageFile, string> indexCollection);
void OnImageAdded(FilePath filePath);
/// <summary> /// <summary>
/// Refreshes the index of local images in the background /// Refreshes the index of local images in the background
/// </summary> /// </summary>

10
StabilityMatrix.Core/Services/ImageIndexService.cs

@ -39,12 +39,7 @@ public class ImageIndexService : IImageIndexService
RelativePath = "inference" RelativePath = "inference"
}; };
/*inferenceImagesSource EventManager.Instance.ImageFileAdded += OnImageFileAdded;
.Connect()
.DeferUntilLoaded()
.SortBy(file => file.LastModifiedAt, SortDirection.Descending)
.Bind(InferenceImages)
.Subscribe();*/
} }
/// <inheritdoc /> /// <inheritdoc />
@ -125,8 +120,7 @@ public class ImageIndexService : IImageIndexService
); );
} }
/// <inheritdoc /> private void OnImageFileAdded(object? sender, FilePath filePath)
public void OnImageAdded(FilePath filePath)
{ {
var fullPath = settingsManager.ImagesDirectory.JoinDir(InferenceImages.RelativePath!); var fullPath = settingsManager.ImagesDirectory.JoinDir(InferenceImages.RelativePath!);

Loading…
Cancel
Save