Browse Source

Add ComfyTask, interrupts, improve progress reports

pull/165/head
Ionite 1 year ago
parent
commit
e5e8706cde
No known key found for this signature in database
  1. 52
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  2. 20
      StabilityMatrix.Avalonia/Views/InferenceTextToImageView.axaml
  3. 47
      StabilityMatrix.Core/Inference/ComfyClient.cs
  4. 7
      StabilityMatrix.Core/Inference/ComfyProgressUpdateEventArgs.cs
  5. 33
      StabilityMatrix.Core/Inference/ComfyTask.cs

52
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -8,17 +8,22 @@ using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Media.Imaging;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
@ -107,6 +112,8 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Batch Size
vmFactory.Get<BatchSizeCardViewModel>(),
});
GenerateImageCommand.WithNotificationErrorHandler(notificationService);
}
private Dictionary<string, ComfyNode> GetCurrentPrompt()
@ -236,11 +243,14 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
return prompt;
}
private void OnProgressUpdateReceived(object? sender, ComfyWebSocketProgressData args)
private void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
{
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Max;
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text = $"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
}
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
@ -274,21 +284,40 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
var nodes = GetCurrentPrompt();
// Connect progress handler
client.ProgressUpdateReceived += OnProgressUpdateReceived;
// client.ProgressUpdateReceived += OnProgressUpdateReceived;
client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try
{
var (response, promptTask) = await client.QueuePromptAsync(nodes, cancellationToken);
Logger.Info(response);
// Register to interrupt if user cancels
cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client.InterruptPromptAsync(new CancellationTokenSource(5000).Token).SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
}
catch (ApiException e)
{
Logger.Warn(e, "Api exception while queuing prompt");
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
// Wait for prompt to finish
await promptTask.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {response.PromptId} finished");
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished");
// Get output images
var outputs = await client.GetImagesForExecutedPromptAsync(
response.PromptId,
promptTask.Id,
cancellationToken
);
@ -339,15 +368,18 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
{
// Disconnect progress handler
OutputProgress.Value = 0;
OutputProgress.Text = "";
ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
client.ProgressUpdateReceived -= OnProgressUpdateReceived;
// client.ProgressUpdateReceived -= OnProgressUpdateReceived;
promptTask?.Dispose();
client.PreviewImageReceived -= OnPreviewImageReceived;
}
}
[RelayCommand(IncludeCancelCommand = true)]
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(CancellationToken cancellationToken = default)
{
try

20
StabilityMatrix.Avalonia/Views/InferenceTextToImageView.axaml

@ -201,13 +201,21 @@
<Grid
x:CompileBindings="False"
DataContext="{Binding ElementName=Dock, Path=DataContext}">
<ProgressBar
Margin="2,1,2,4"
VerticalAlignment="Top"
<StackPanel
DataContext="{Binding OutputProgress}"
IsVisible="{Binding IsProgressVisible}"
Maximum="{Binding Maximum}"
Value="{Binding Value}" />
Margin="2,1,2,4"
Spacing="4"
VerticalAlignment="Top">
<ProgressBar
IsVisible="{Binding IsProgressVisible}"
Maximum="{Binding Maximum}"
Value="{Binding Value}" />
<TextBlock
IsVisible="{Binding IsTextVisible}"
TextAlignment="Center"
Text="{Binding Text}"/>
</StackPanel>
<controls:ImageGalleryCard
Grid.Row="0"
DataContext="{Binding ImageGalleryCardViewModel}" />

47
StabilityMatrix.Core/Inference/ComfyClient.cs

@ -31,7 +31,12 @@ public class ComfyClient : InferenceClientBase
/// <summary>
/// Dictionary of ongoing prompt execution tasks
/// </summary>
public ConcurrentDictionary<string, TaskCompletionSource> PromptTasks { get; } = new();
public ConcurrentDictionary<string, ComfyTask> PromptTasks { get; } = new();
/// <summary>
/// Current running prompt task
/// </summary>
private ComfyTask? currentPromptTask;
/// <summary>
/// Event raised when a progress update is received from the server
@ -138,13 +143,23 @@ public class ComfyClient : InferenceClientBase
{
if (PromptTasks.TryRemove(executingData.PromptId, out var task))
{
task.RunningNode = null;
task.SetResult();
currentPromptTask = null;
}
else
{
Logger.Warn($"Could not find task for prompt {executingData.PromptId}, skipping");
}
}
// Otherwise set the task's active node to the one received
else
{
if (PromptTasks.TryGetValue(executingData.PromptId, out var task))
{
task.RunningNode = executingData.Node;
}
}
ExecutingUpdateReceived?.Invoke(this, executingData);
}
@ -167,7 +182,10 @@ public class ComfyClient : InferenceClientBase
Logger.Warn($"Could not parse progress data {json.Data}, skipping");
return;
}
// Set for the current prompt task
currentPromptTask?.OnProgressUpdate(progressData);
ProgressUpdateReceived?.Invoke(this, progressData);
}
else
@ -221,7 +239,7 @@ public class ComfyClient : InferenceClientBase
.ConfigureAwait(false);
}
public async Task<(ComfyPromptResponse, Task)> QueuePromptAsync(
public async Task<ComfyTask> QueuePromptAsync(
Dictionary<string, ComfyNode> nodes,
CancellationToken cancellationToken = default
)
@ -229,11 +247,26 @@ public class ComfyClient : InferenceClientBase
var request = new ComfyPromptRequest { ClientId = ClientId, Prompt = nodes };
var result = await comfyApi.PostPrompt(request, cancellationToken).ConfigureAwait(false);
// Add task to dictionary
var tcs = new TaskCompletionSource();
PromptTasks[result.PromptId] = tcs;
// Add task to dictionary and set it as the current task
var task = new ComfyTask(result.PromptId);
PromptTasks[result.PromptId] = task;
currentPromptTask = task;
return (result, tcs.Task);
return task;
}
public async Task InterruptPromptAsync(CancellationToken cancellationToken = default)
{
await comfyApi.PostInterrupt(cancellationToken).ConfigureAwait(false);
// Set the current task to null, and remove it from the dictionary
if (currentPromptTask is { } task)
{
PromptTasks.TryRemove(task.Id, out _);
task.TrySetCanceled(cancellationToken);
task.Dispose();
currentPromptTask = null;
}
}
public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync(

7
StabilityMatrix.Core/Inference/ComfyProgressUpdateEventArgs.cs

@ -0,0 +1,7 @@
namespace StabilityMatrix.Core.Inference;
public readonly record struct ComfyProgressUpdateEventArgs(
int Value,
int Maximum,
string? TaskId,
string? RunningNode);

33
StabilityMatrix.Core/Inference/ComfyTask.cs

@ -0,0 +1,33 @@
using System.Reactive;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Core.Inference;
public class ComfyTask : TaskCompletionSource, IDisposable
{
public string Id { get; set; }
public string? RunningNode { get; set; }
public EventHandler<ComfyProgressUpdateEventArgs>? ProgressUpdate;
public ComfyTask(string id)
{
Id = id;
}
/// <summary>
/// Handler for progress updates
/// </summary>
public void OnProgressUpdate(ComfyWebSocketProgressData update)
{
ProgressUpdate?.Invoke(this, new ComfyProgressUpdateEventArgs(update.Value, update.Max, Id, RunningNode));
}
/// <inheritdoc />
public void Dispose()
{
ProgressUpdate = null;
GC.SuppressFinalize(this);
}
}
Loading…
Cancel
Save