Browse Source

Abstract common generation functionality to InferenceGenerationViewModelBase

pull/165/head
Ionite 1 year ago
parent
commit
11cb46abf8
No known key found for this signature in database
  1. 3
      StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs
  2. 6
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  3. 24
      StabilityMatrix.Avalonia/Models/InferenceProjectType.cs
  4. 143
      StabilityMatrix.Avalonia/Services/ServiceManager.cs
  5. 354
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  6. 298
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  7. 19
      StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
  8. 13
      StabilityMatrix.Avalonia/Views/InferencePage.axaml
  9. 10
      StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs
  10. 5
      StabilityMatrix.Core/Helper/EventManager.cs
  11. 5
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  12. 2
      StabilityMatrix.Core/Services/IImageIndexService.cs
  13. 10
      StabilityMatrix.Core/Services/ImageIndexService.cs

3
StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs

@ -54,9 +54,6 @@ public class MockImageIndexService : IImageIndexService
return Task.CompletedTask;
}
/// <inheritdoc />
public void OnImageAdded(FilePath filePath) { }
/// <inheritdoc />
public void BackgroundRefreshIndex()
{

6
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -253,11 +253,13 @@ public static class ComfyNodeBuilderExtensions
var vaeDecoder = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.VAEDecode(
"VAEDecode",
builder.Connections.Latent!,
builder.Connections.Latent
?? throw new InvalidOperationException("Latent source not set"),
builder.Connections.GetRefinerOrBaseVAE()
)
);
builder.Connections.Image = vaeDecoder.Output;
builder.Connections.ImageSize = builder.Connections.LatentSize;
}
var saveImage = builder.Nodes.AddNamedNode(
@ -272,6 +274,8 @@ public static class ComfyNodeBuilderExtensions
}
);
builder.Connections.OutputNodes.Add(saveImage);
return saveImage.Name;
}
}

24
StabilityMatrix.Avalonia/Models/InferenceProjectType.cs

@ -1,7 +1,29 @@
namespace StabilityMatrix.Avalonia.Models;
using System;
using StabilityMatrix.Avalonia.ViewModels.Inference;
namespace StabilityMatrix.Avalonia.Models;
public enum InferenceProjectType
{
Unknown,
TextToImage,
ImageToImage,
Inpainting,
Upscale
}
public static class InferenceProjectTypeExtensions
{
public static Type? ToViewModelType(this InferenceProjectType type)
{
return type switch
{
InferenceProjectType.TextToImage => typeof(InferenceTextToImageViewModel),
InferenceProjectType.ImageToImage => null,
InferenceProjectType.Inpainting => null,
InferenceProjectType.Upscale => typeof(InferenceImageUpscaleViewModel),
InferenceProjectType.Unknown => null,
_ => throw new ArgumentOutOfRangeException(nameof(type), type, null)
};
}
}

143
StabilityMatrix.Avalonia/Services/ServiceManager.cs

@ -13,94 +13,103 @@ public class ServiceManager<T>
{
// Holds providers
private readonly Dictionary<Type, Func<T>> providers = new();
// Holds singleton instances
private readonly Dictionary<Type, T> instances = new();
/// <summary>
/// Register a new dialog view model (singleton instance)
/// </summary>
public ServiceManager<T> Register<TService>(TService instance) where TService : T
public ServiceManager<T> Register<TService>(TService instance)
where TService : T
{
if (instance is null) throw new ArgumentNullException(nameof(instance));
if (instance is null)
throw new ArgumentNullException(nameof(instance));
lock (instances)
{
if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService)))
{
throw new ArgumentException(
$"Service of type {typeof(TService)} is already registered for {typeof(T)}");
$"Service of type {typeof(TService)} is already registered for {typeof(T)}"
);
}
instances[instance.GetType()] = instance;
}
return this;
}
/// <summary>
/// Register a new dialog view model provider action (called on each dialog creation)
/// </summary>
public ServiceManager<T> Register<TService>(Func<TService> provider) where TService : T
public ServiceManager<T> Register<TService>(Func<TService> provider)
where TService : T
{
lock (providers)
{
if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService)))
{
throw new ArgumentException(
$"Service of type {typeof(TService)} is already registered for {typeof(T)}");
$"Service of type {typeof(TService)} is already registered for {typeof(T)}"
);
}
// Return type is wrong during build with method group syntax
// ReSharper disable once RedundantCast
providers[typeof(TService)] = () => (TService) provider();
providers[typeof(TService)] = () => (TService)provider();
}
return this;
}
/// <summary>
/// Register a new dialog view model instance using a service provider
/// Equal to Register[TService](serviceProvider.GetRequiredService[TService])
/// </summary>
public ServiceManager<T> RegisterProvider<TService>(IServiceProvider provider) where TService : notnull, T
public ServiceManager<T> RegisterProvider<TService>(IServiceProvider provider)
where TService : notnull, T
{
lock (providers)
{
if (instances.ContainsKey(typeof(TService)) || providers.ContainsKey(typeof(TService)))
{
throw new ArgumentException(
$"Service of type {typeof(TService)} is already registered for {typeof(T)}");
$"Service of type {typeof(TService)} is already registered for {typeof(T)}"
);
}
// Return type is wrong during build with method group syntax
// ReSharper disable once RedundantCast
providers[typeof(TService)] = () => (TService) provider.GetRequiredService<TService>();
providers[typeof(TService)] = () => (TService)provider.GetRequiredService<TService>();
}
return this;
}
/// <summary>
/// Get a view model instance from runtime type
/// </summary>
[SuppressMessage("ReSharper", "InconsistentlySynchronizedField")]
public T Get(Type serviceType)
{
if (!serviceType.IsAssignableFrom(typeof(T)))
if (!serviceType.IsAssignableTo(typeof(T)))
{
throw new ArgumentException(
$"Service type {serviceType} is not assignable from {typeof(T)}");
$"Service type {serviceType} is not assignable to {typeof(T)}"
);
}
if (instances.TryGetValue(serviceType, out var instance))
{
if (instance is null)
{
throw new ArgumentException(
$"Service of type {serviceType} was registered as null");
$"Service of type {serviceType} was registered as null"
);
}
return (T) instance;
return (T)instance;
}
if (providers.TryGetValue(serviceType, out var provider))
@ -108,35 +117,40 @@ public class ServiceManager<T>
if (provider is null)
{
throw new ArgumentException(
$"Service of type {serviceType} was registered as null");
$"Service of type {serviceType} was registered as null"
);
}
var result = provider();
if (result is null)
{
throw new ArgumentException(
$"Service provider for type {serviceType} returned null");
$"Service provider for type {serviceType} returned null"
);
}
return (T) result;
return (T)result;
}
throw new ArgumentException(
$"Service of type {serviceType} is not registered for {typeof(T)}");
$"Service of type {serviceType} is not registered for {typeof(T)}"
);
}
/// <summary>
/// Get a view model instance
/// </summary>
[SuppressMessage("ReSharper", "InconsistentlySynchronizedField")]
public TService Get<TService>() where TService : T
public TService Get<TService>()
where TService : T
{
if (instances.TryGetValue(typeof(TService), out var instance))
{
if (instance is null)
{
throw new ArgumentException(
$"Service of type {typeof(TService)} was registered as null");
$"Service of type {typeof(TService)} was registered as null"
);
}
return (TService) instance;
return (TService)instance;
}
if (providers.TryGetValue(typeof(TService), out var provider))
@ -144,83 +158,102 @@ public class ServiceManager<T>
if (provider is null)
{
throw new ArgumentException(
$"Service of type {typeof(TService)} was registered as null");
$"Service of type {typeof(TService)} was registered as null"
);
}
var result = provider();
if (result is null)
{
throw new ArgumentException(
$"Service provider for type {typeof(TService)} returned null");
$"Service provider for type {typeof(TService)} returned null"
);
}
return (TService) result;
return (TService)result;
}
throw new ArgumentException(
$"Service of type {typeof(TService)} is not registered for {typeof(T)}");
$"Service of type {typeof(TService)} is not registered for {typeof(T)}"
);
}
/// <summary>
/// Get a view model instance with an initializer parameter
/// </summary>
public TService Get<TService>(Func<TService, TService> initializer) where TService : T
public TService Get<TService>(Func<TService, TService> initializer)
where TService : T
{
var instance = Get<TService>();
return initializer(instance);
}
/// <summary>
/// Get a view model instance with an initializer for a mutable instance
/// </summary>
public TService Get<TService>(Action<TService> initializer) where TService : T
public TService Get<TService>(Action<TService> initializer)
where TService : T
{
var instance = Get<TService>();
initializer(instance);
return instance;
}
/// <summary>
/// Get a view model instance, set as DataContext of its View, and return
/// a BetterContentDialog with that View as its Content
/// </summary>
public BetterContentDialog GetDialog<TService>() where TService : T
public BetterContentDialog GetDialog<TService>()
where TService : T
{
var instance = Get<TService>()!;
if (Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute)) is not ViewAttribute
viewAttr)
if (
Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute))
is not ViewAttribute viewAttr
)
{
throw new InvalidOperationException($"View not found for {instance.GetType().FullName}");
throw new InvalidOperationException(
$"View not found for {instance.GetType().FullName}"
);
}
if (Activator.CreateInstance(viewAttr.ViewType) is not Control view)
{
throw new NullReferenceException($"Unable to create instance for {instance.GetType().FullName}");
throw new NullReferenceException(
$"Unable to create instance for {instance.GetType().FullName}"
);
}
return new BetterContentDialog { Content = view };
}
/// <summary>
/// Get a view model instance with initializer, set as DataContext of its View, and return
/// a BetterContentDialog with that View as its Content
/// </summary>
public BetterContentDialog GetDialog<TService>(Action<TService> initializer) where TService : T
public BetterContentDialog GetDialog<TService>(Action<TService> initializer)
where TService : T
{
var instance = Get(initializer)!;
if (Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute)) is not ViewAttribute
viewAttr)
if (
Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute))
is not ViewAttribute viewAttr
)
{
throw new InvalidOperationException($"View not found for {instance.GetType().FullName}");
throw new InvalidOperationException(
$"View not found for {instance.GetType().FullName}"
);
}
if (Activator.CreateInstance(viewAttr.ViewType) is not Control view)
{
throw new NullReferenceException($"Unable to create instance for {instance.GetType().FullName}");
throw new NullReferenceException(
$"Unable to create instance for {instance.GetType().FullName}"
);
}
view.DataContext = instance;
return new BetterContentDialog { Content = view };
}
}

354
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -0,0 +1,354 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
/// <summary>
/// Abstract base class for tab view models that generate images using ClientManager.
/// This includes a progress reporter, image output view model, and generation virtual methods.
/// </summary>
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract partial class InferenceGenerationViewModelBase
: InferenceTabViewModelBase,
IImageGalleryComponent
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService;
[JsonPropertyName("ImageGallery")]
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
[JsonIgnore]
public ImageFolderCardViewModel ImageFolderCardViewModel { get; }
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
/// <inheritdoc />
protected InferenceGenerationViewModelBase(
ServiceManager<ViewModelBase> vmFactory,
IInferenceClientManager inferenceClientManager,
INotificationService notificationService
)
{
this.notificationService = notificationService;
ClientManager = inferenceClientManager;
ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
ImageFolderCardViewModel = vmFactory.Get<ImageFolderCardViewModel>();
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
}
/// <summary>
/// Builds the image generation prompt
/// </summary>
protected virtual void BuildPrompt(BuildPromptEventArgs args) { }
/// <summary>
/// Runs a generation task
/// </summary>
/// <exception cref="InvalidOperationException">Thrown if args.Parameters or args.Project are null</exception>
protected async Task RunGeneration(
ImageGenerationEventArgs args,
CancellationToken cancellationToken
)
{
var client = args.Client;
var nodes = args.Nodes;
// Checks
if (args.Parameters is null)
throw new InvalidOperationException("Parameters is null");
if (args.Project is null)
throw new InvalidOperationException("Project is null");
if (args.OutputNodeNames.Count == 0)
throw new InvalidOperationException("OutputNodeNames is empty");
if (client.OutputImagesDir is null)
throw new InvalidOperationException("OutputImagesDir is null");
// Connect preview image handler
client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try
{
// Register to interrupt if user cancels
cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client
.InterruptPromptAsync(new CancellationTokenSource(5000).Token)
.SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
}
catch (ApiException e)
{
Logger.Warn(e, "Api exception while queuing prompt");
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
// Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished");
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
promptTask.Id,
cancellationToken
);
ImageGalleryCardViewModel.ImageSources.Clear();
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
await ProcessOutputImages(images, args);
}
finally
{
// Disconnect progress handler
client.PreviewImageReceived -= OnPreviewImageReceived;
// Clear progress
OutputProgress.Value = 0;
OutputProgress.Text = "";
ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
// Cleanup tasks
promptTask?.Dispose();
}
}
/// <summary>
/// Handles image output metadata for generation runs
/// </summary>
private async Task ProcessOutputImages(
IEnumerable<ComfyImage> images,
ImageGenerationEventArgs args
)
{
// Write metadata to images
var outputImages = new List<ImageSource>();
foreach (
var filePath in images.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
)
{
var bytesWithMetadata = PngDataHelper.AddMetadata(
await filePath.ReadAllBytesAsync(),
args.Parameters!,
args.Project!
);
await using (var outputStream = filePath.Info.OpenWrite())
{
await outputStream.WriteAsync(bytesWithMetadata);
await outputStream.FlushAsync();
}
outputImages.Add(new ImageSource(filePath));
EventManager.Instance.OnImageFileAdded(filePath);
}
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
var loadedImages = outputImages
.Select(i => SKImage.FromEncodedData(i.LocalFile?.Info.OpenRead()))
.ToImmutableArray();
var grid = ImageProcessor.CreateImageGrid(loadedImages);
var gridBytes = grid.Encode().ToArray();
var gridBytesWithMetadata = PngDataHelper.AddMetadata(
gridBytes,
args.Parameters!,
args.Project!
);
// Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = args.Client.OutputImagesDir!.JoinFile($"grid-{lastName}");
await using (var fileStream = gridPath.Info.OpenWrite())
{
await fileStream.WriteAsync(gridBytesWithMetadata);
}
// Insert to start of images
var gridImage = new ImageSource(gridPath);
// Preload
await gridImage.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(gridImage);
EventManager.Instance.OnImageFileAdded(gridPath);
}
// Add rest of images
foreach (var img in outputImages)
{
// Preload
await img.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(img);
}
}
/// <summary>
/// Implementation for Generate Image
/// </summary>
protected virtual Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
{
return Task.CompletedTask;
}
/// <summary>
/// Command for the Generate Image button
/// </summary>
/// <param name="options">Optional overrides (side buttons)</param>
/// <param name="cancellationToken">Cancellation token</param>
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
GenerateFlags options = default,
CancellationToken cancellationToken = default
)
{
try
{
var overrides = GenerateOverrides.FromFlags(options);
await GenerateImageImpl(overrides, cancellationToken);
}
catch (OperationCanceledException)
{
Logger.Debug($"Image Generation Canceled");
}
}
/// <summary>
/// Handles the preview image received event from the websocket.
/// Updates the preview image in the image gallery.
/// </summary>
protected virtual void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
}
/// <summary>
/// Handles the progress update received event from the websocket.
/// Updates the progress view model.
/// </summary>
protected virtual void OnProgressUpdateReceived(
object? sender,
ComfyProgressUpdateEventArgs args
)
{
Dispatcher.UIThread.Post(() =>
{
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text =
$"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
});
}
public class ImageGenerationEventArgs : EventArgs
{
public required ComfyClient Client { get; init; }
public required NodeDictionary Nodes { get; init; }
public required IReadOnlyList<string> OutputNodeNames { get; init; }
public GenerationParameters? Parameters { get; set; }
public InferenceProjectDocument? Project { get; set; }
}
public class BuildPromptEventArgs : EventArgs
{
public ComfyNodeBuilder Builder { get; } = new();
public GenerateOverrides Overrides { get; set; } = new();
}
[Flags]
public enum GenerateFlags
{
None = 0,
HiresFixEnable = 1 << 1,
HiresFixDisable = 1 << 2,
UseCurrentSeed = 1 << 3,
UseRandomSeed = 1 << 4
}
public class GenerateOverrides
{
public bool? IsHiresFixEnabled { get; set; }
public bool? UseCurrentSeed { get; set; }
public static GenerateOverrides FromFlags(GenerateFlags flags)
{
var overrides = new GenerateOverrides()
{
IsHiresFixEnabled = flags.HasFlag(GenerateFlags.HiresFixEnable)
? true
: flags.HasFlag(GenerateFlags.HiresFixDisable)
? false
: null,
UseCurrentSeed = flags.HasFlag(GenerateFlags.UseCurrentSeed)
? true
: flags.HasFlag(GenerateFlags.UseRandomSeed)
? false
: null
};
return overrides;
}
}
}

298
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -1,35 +1,21 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using DynamicData.Binding;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView;
@ -38,19 +24,12 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), persistent: true)]
public partial class InferenceTextToImageViewModel
: InferenceTabViewModelBase,
IImageGalleryComponent
public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService;
private readonly ServiceManager<ViewModelBase> vmFactory;
private readonly IModelIndexService modelIndexService;
private readonly IImageIndexService imageIndexService;
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
[JsonIgnore]
public StackCardViewModel StackCardViewModel { get; }
@ -61,12 +40,6 @@ public partial class InferenceTextToImageViewModel
[JsonPropertyName("Sampler")]
public SamplerCardViewModel SamplerCardViewModel { get; }
[JsonPropertyName("ImageGallery")]
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
[JsonPropertyName("ImageFolder")]
public ImageFolderCardViewModel ImageFolderCardViewModel { get; }
[JsonPropertyName("Prompt")]
public PromptCardViewModel PromptCardViewModel { get; }
@ -97,26 +70,16 @@ public partial class InferenceTextToImageViewModel
set => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled = value;
}
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
[ObservableProperty]
[property: JsonIgnore]
private string? outputImageSource;
public InferenceTextToImageViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
ServiceManager<ViewModelBase> vmFactory,
IModelIndexService modelIndexService,
IImageIndexService imageIndexService
IModelIndexService modelIndexService
)
: base(vmFactory, inferenceClientManager, notificationService)
{
this.notificationService = notificationService;
this.vmFactory = vmFactory;
this.modelIndexService = modelIndexService;
this.imageIndexService = imageIndexService;
ClientManager = inferenceClientManager;
// Get sub view models from service manager
@ -133,8 +96,6 @@ public partial class InferenceTextToImageViewModel
samplerCard.IsSchedulerSelectionEnabled = true;
});
ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
ImageFolderCardViewModel = vmFactory.Get<ImageFolderCardViewModel>();
PromptCardViewModel = vmFactory.Get<PromptCardViewModel>();
HiresSamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{
@ -181,17 +142,16 @@ public partial class InferenceTextToImageViewModel
SamplerCardViewModel.IsRefinerStepsEnabled =
e.Sender is { IsRefinerSelectionEnabled: true, SelectedRefiner: not null };
});
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
}
private (NodeDictionary prompt, string[] outputs) BuildPrompt(
GenerateOverrides? overrides = null
)
/// <inheritdoc />
protected override void BuildPrompt(BuildPromptEventArgs args)
{
using var _ = new CodeTimer();
base.BuildPrompt(args);
var builder = new ComfyNodeBuilder();
using var _ = CodeTimer.StartDebug();
var builder = args.Builder;
var nodes = builder.Nodes;
// Setup empty latent
@ -232,7 +192,7 @@ public partial class InferenceTextToImageViewModel
}
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
if (overrides?.IsHiresFixEnabled ?? IsHiresFixEnabled)
if (args.Overrides.IsHiresFixEnabled ?? IsHiresFixEnabled)
{
// Requested upscale to this size
var hiresSize = builder.Connections.GetScaledLatentSize(
@ -309,35 +269,12 @@ public partial class InferenceTextToImageViewModel
// Set as the image output
builder.Connections.Image = postUpscaleGroup.Output;
}
// Output
var outputName = builder.SetupOutputImage();
return (builder.ToNodeDictionary(), new[] { outputName });
}
private void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
{
Dispatcher.UIThread.Post(() =>
{
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text =
$"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
});
}
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
{
ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
}
private async Task GenerateImageImpl(
GenerateOverrides? overrides = null,
CancellationToken cancellationToken = default
/// <inheritdoc />
protected override async Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
{
// Validate the prompts
@ -359,192 +296,33 @@ public partial class InferenceTextToImageViewModel
seedCard.GenerateNewSeed();
}
var client = ClientManager.Client;
var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides };
BuildPrompt(buildPromptArgs);
var (nodes, outputNodeNames) = BuildPrompt(overrides);
var generationInfo = new GenerationParameters
{
Seed = (ulong)seedCard.Seed,
Steps = SamplerCardViewModel.Steps,
CfgScale = SamplerCardViewModel.CfgScale,
Sampler = SamplerCardViewModel.SelectedSampler?.Name,
ModelName = ModelCardViewModel.SelectedModelName,
// TODO: ModelHash
PositivePrompt = PromptCardViewModel.PromptDocument.Text,
NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text
};
var smproj = InferenceProjectDocument.FromLoadable(this);
// Connect preview image handler
client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try
var generationArgs = new ImageGenerationEventArgs
{
// Register to interrupt if user cancels
cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client
.InterruptPromptAsync(new CancellationTokenSource(5000).Token)
.SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
}
catch (ApiException e)
Client = ClientManager.Client,
Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
Parameters = new GenerationParameters
{
Logger.Warn(e, "Api exception while queuing prompt");
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
// Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished");
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
promptTask.Id,
cancellationToken
);
ImageGalleryCardViewModel.ImageSources.Clear();
if (!imageOutputs.TryGetValue(outputNodeNames[0], out var images) || images is null)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
List<ImageSource> outputImages;
// Use local file path if available, otherwise use remote URL
if (client.OutputImagesDir is { } outputPath)
{
outputImages = new List<ImageSource>();
foreach (var image in images)
{
var filePath = image.ToFilePath(outputPath);
var bytesWithMetadata = PngDataHelper.AddMetadata(
await filePath.ReadAllBytesAsync(),
generationInfo,
smproj
);
/*await using (var readStream = filePath.Info.OpenWrite())
{
using (var reader = new BinaryReader(readStream))
{
}
}*/
await using (var outputStream = filePath.Info.OpenWrite())
{
await outputStream.WriteAsync(bytesWithMetadata);
await outputStream.FlushAsync();
}
outputImages.Add(new ImageSource(filePath));
imageIndexService.OnImageAdded(filePath);
}
}
else
{
outputImages = images!
.Select(i => new ImageSource(i.ToUri(client.BaseAddress)))
.ToList();
}
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
var loadedImages = outputImages
.Select(i => SKImage.FromEncodedData(i.LocalFile?.Info.OpenRead()))
.ToImmutableArray();
var grid = ImageProcessor.CreateImageGrid(loadedImages);
var gridBytes = grid.Encode().ToArray();
var gridBytesWithMetadata = PngDataHelper.AddMetadata(
gridBytes,
generationInfo,
smproj
);
// Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = client.OutputImagesDir!.JoinFile($"grid-{lastName}");
await using (var fileStream = gridPath.Info.OpenWrite())
{
await fileStream.WriteAsync(gridBytesWithMetadata, cancellationToken);
}
// Insert to start of images
var gridImage = new ImageSource(gridPath);
// Preload
await gridImage.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(gridImage);
imageIndexService.OnImageAdded(gridPath);
}
// Add rest of images
foreach (var img in outputImages)
{
// Preload
await img.GetBitmapAsync();
ImageGalleryCardViewModel.ImageSources.Add(img);
}
}
finally
{
// Disconnect progress handler
OutputProgress.Value = 0;
OutputProgress.Text = "";
ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
promptTask?.Dispose();
client.PreviewImageReceived -= OnPreviewImageReceived;
}
}
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
string? options = null,
CancellationToken cancellationToken = default
)
{
try
{
var overrides = new GenerateOverrides
{
IsHiresFixEnabled = options?.Contains("hires_fix"),
UseCurrentSeed = options?.Contains("current_seed")
};
await GenerateImageImpl(overrides, cancellationToken);
}
catch (OperationCanceledException e)
{
Logger.Debug($"[Image Generation Canceled] {e.Message}");
}
}
Seed = (ulong)seedCard.Seed,
Steps = SamplerCardViewModel.Steps,
CfgScale = SamplerCardViewModel.CfgScale,
Sampler = SamplerCardViewModel.SelectedSampler?.Name,
ModelName = ModelCardViewModel.SelectedModelName,
ModelHash = ModelCardViewModel
.SelectedModel
?.Local
?.ConnectedModelInfo
?.Hashes
.SHA256,
PositivePrompt = PromptCardViewModel.PromptDocument.Text,
NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text
},
Project = InferenceProjectDocument.FromLoadable(this)
};
internal class GenerateOverrides
{
public bool? IsHiresFixEnabled { get; set; }
public bool? UseCurrentSeed { get; set; }
await RunGeneration(generationArgs, cancellationToken);
}
}

19
StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs

@ -70,6 +70,8 @@ public partial class InferenceViewModel : PageViewModelBase
public IInferenceClientManager ClientManager { get; }
public SharedState SharedState { get; }
public ObservableCollection<InferenceTabViewModelBase> Tabs { get; } = new();
[ObservableProperty]
@ -93,7 +95,8 @@ public partial class InferenceViewModel : PageViewModelBase
IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
IModelIndexService modelIndexService,
ILiteDbContext liteDbContext
ILiteDbContext liteDbContext,
SharedState sharedState
)
{
this.vmFactory = vmFactory;
@ -103,6 +106,7 @@ public partial class InferenceViewModel : PageViewModelBase
this.liteDbContext = liteDbContext;
ClientManager = inferenceClientManager;
SharedState = sharedState;
// Keep RunningPackage updated with the current package pair
EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged;
@ -240,8 +244,10 @@ public partial class InferenceViewModel : PageViewModelBase
continue;
}
var projectPath = projectFile.ToString();
var entry = await liteDbContext.InferenceProjects.FindOneAsync(
p => p.FilePath == projectFile.ToString()
p => p.FilePath == projectPath
);
// Create if not found
@ -299,9 +305,14 @@ public partial class InferenceViewModel : PageViewModelBase
/// When the + button on the tab control is clicked, add a new tab.
/// </summary>
[RelayCommand]
private void AddTab()
public void AddTab(InferenceProjectType type = InferenceProjectType.TextToImage)
{
var tab = vmFactory.Get<InferenceTextToImageViewModel>();
if (type.ToViewModelType() is not { } vmType)
{
return;
}
var tab = (InferenceTabViewModelBase)vmFactory.Get(vmType);
Tabs.Add(tab);
// Set as new selected tab

13
StabilityMatrix.Avalonia/Views/InferencePage.axaml

@ -40,7 +40,7 @@
<ui:CommandBarButton Label="Delete" ToolTip.Tip="Delete" />-->
<ui:CommandBarFlyout.SecondaryCommands>
<ui:CommandBarButton
Click="AddTabMenu_TextToImageButton_OnClick"
Click="AddTabMenu_TextToImage_OnClick"
IconSource="FullScreenMaximize"
Label="Text to Image"
ToolTip.Tip="Text to Image" />
@ -54,6 +54,17 @@
IsEnabled="False"
Label="Inpaint"
ToolTip.Tip="Inpaint" />
<ui:CommandBarButton
Click="AddTabMenu_Upscale_OnClick"
IsEnabled="{Binding SharedState.IsDebugMode}"
Label="Upscale"
ToolTip.Tip="Upscale">
<ui:CommandBarButton.IconSource>
<fluentIcons:SymbolIconSource
FontSize="10"
Symbol="ResizeImage"/>
</ui:CommandBarButton.IconSource>
</ui:CommandBarButton>
</ui:CommandBarFlyout.SecondaryCommands>
</ui:CommandBarFlyout>
</controls:UserControlBase.Resources>

10
StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs

@ -6,6 +6,7 @@ using Avalonia.Input;
using Avalonia.Interactivity;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.ViewModels;
namespace StabilityMatrix.Avalonia.Views;
@ -47,8 +48,13 @@ public partial class InferencePage : UserControlBase
addTabFlyout.ShowAt(AddButton);
}
private void AddTabMenu_TextToImageButton_OnClick(object? sender, RoutedEventArgs e)
private void AddTabMenu_TextToImage_OnClick(object? sender, RoutedEventArgs e)
{
(DataContext as InferenceViewModel)!.AddTabCommand.Execute(null);
(DataContext as InferenceViewModel)!.AddTab();
}
private void AddTabMenu_Upscale_OnClick(object? sender, RoutedEventArgs e)
{
(DataContext as InferenceViewModel)!.AddTab(InferenceProjectType.Upscale);
}
}

5
StabilityMatrix.Core/Helper/EventManager.cs

@ -1,5 +1,6 @@
using System.Globalization;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Models.Update;
@ -32,6 +33,8 @@ public class EventManager
public event EventHandler? ModelIndexChanged;
public event EventHandler<FilePath>? ImageFileAdded;
public void OnGlobalProgressChanged(int progress) =>
GlobalProgressChanged?.Invoke(this, progress);
@ -69,4 +72,6 @@ public class EventManager
public void OnCultureChanged(CultureInfo culture) => CultureChanged?.Invoke(this, culture);
public void OnModelIndexChanged() => ModelIndexChanged?.Invoke(this, EventArgs.Empty);
public void OnImageFileAdded(FilePath filePath) => ImageFileAdded?.Invoke(this, filePath);
}

5
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -529,6 +529,11 @@ public class ComfyNodeBuilder
public Size LatentSize { get; set; }
public ImageNodeConnection? Image { get; set; }
public Size ImageSize { get; set; }
public List<NamedComfyNode> OutputNodes { get; } = new();
public IEnumerable<string> OutputNodeNames => OutputNodes.Select(n => n.Name);
/// <summary>
/// Gets the latent size scaled by a given factor

2
StabilityMatrix.Core/Services/IImageIndexService.cs

@ -21,8 +21,6 @@ public interface IImageIndexService
Task RefreshIndex(IndexCollection<LocalImageFile, string> indexCollection);
void OnImageAdded(FilePath filePath);
/// <summary>
/// Refreshes the index of local images in the background
/// </summary>

10
StabilityMatrix.Core/Services/ImageIndexService.cs

@ -39,12 +39,7 @@ public class ImageIndexService : IImageIndexService
RelativePath = "inference"
};
/*inferenceImagesSource
.Connect()
.DeferUntilLoaded()
.SortBy(file => file.LastModifiedAt, SortDirection.Descending)
.Bind(InferenceImages)
.Subscribe();*/
EventManager.Instance.ImageFileAdded += OnImageFileAdded;
}
/// <inheritdoc />
@ -125,8 +120,7 @@ public class ImageIndexService : IImageIndexService
);
}
/// <inheritdoc />
public void OnImageAdded(FilePath filePath)
private void OnImageFileAdded(object? sender, FilePath filePath)
{
var fullPath = settingsManager.ImagesDirectory.JoinDir(InferenceImages.RelativePath!);

Loading…
Cancel
Save