Browse Source

Add ComfyTask, interrupts, improve progress reports

pull/165/head
Ionite 1 year ago
parent
commit
e5e8706cde
No known key found for this signature in database
  1. 52
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  2. 14
      StabilityMatrix.Avalonia/Views/InferenceTextToImageView.axaml
  3. 45
      StabilityMatrix.Core/Inference/ComfyClient.cs
  4. 7
      StabilityMatrix.Core/Inference/ComfyProgressUpdateEventArgs.cs
  5. 33
      StabilityMatrix.Core/Inference/ComfyTask.cs

52
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -8,17 +8,22 @@ using System.Text.Json.Nodes;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Media.Imaging; using Avalonia.Media.Imaging;
using AvaloniaEdit.Document; using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using NLog; using NLog;
using Refit;
using SkiaSharp; using SkiaSharp;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.Views; using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration #pragma warning disable CS0657 // Not a valid attribute location for this declaration
@ -107,6 +112,8 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Batch Size // Batch Size
vmFactory.Get<BatchSizeCardViewModel>(), vmFactory.Get<BatchSizeCardViewModel>(),
}); });
GenerateImageCommand.WithNotificationErrorHandler(notificationService);
} }
private Dictionary<string, ComfyNode> GetCurrentPrompt() private Dictionary<string, ComfyNode> GetCurrentPrompt()
@ -236,11 +243,14 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
return prompt; return prompt;
} }
private void OnProgressUpdateReceived(object? sender, ComfyWebSocketProgressData args) private void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
{ {
OutputProgress.Value = args.Value; OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Max; OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false; OutputProgress.IsIndeterminate = false;
OutputProgress.Text = $"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
} }
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args) private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
@ -274,21 +284,40 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
var nodes = GetCurrentPrompt(); var nodes = GetCurrentPrompt();
// Connect progress handler // Connect progress handler
client.ProgressUpdateReceived += OnProgressUpdateReceived; // client.ProgressUpdateReceived += OnProgressUpdateReceived;
client.PreviewImageReceived += OnPreviewImageReceived; client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try try
{ {
var (response, promptTask) = await client.QueuePromptAsync(nodes, cancellationToken); // Register to interrupt if user cancels
Logger.Info(response); cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client.InterruptPromptAsync(new CancellationTokenSource(5000).Token).SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
}
catch (ApiException e)
{
Logger.Warn(e, "Api exception while queuing prompt");
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
// Wait for prompt to finish // Wait for prompt to finish
await promptTask.WaitAsync(cancellationToken); await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {response.PromptId} finished"); Logger.Trace($"Prompt task {promptTask.Id} finished");
// Get output images // Get output images
var outputs = await client.GetImagesForExecutedPromptAsync( var outputs = await client.GetImagesForExecutedPromptAsync(
response.PromptId, promptTask.Id,
cancellationToken cancellationToken
); );
@ -339,15 +368,18 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
{ {
// Disconnect progress handler // Disconnect progress handler
OutputProgress.Value = 0; OutputProgress.Value = 0;
OutputProgress.Text = "";
ImageGalleryCardViewModel.PreviewImage?.Dispose(); ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null; ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false; ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
client.ProgressUpdateReceived -= OnProgressUpdateReceived;
// client.ProgressUpdateReceived -= OnProgressUpdateReceived;
promptTask?.Dispose();
client.PreviewImageReceived -= OnPreviewImageReceived; client.PreviewImageReceived -= OnPreviewImageReceived;
} }
} }
[RelayCommand(IncludeCancelCommand = true)] [RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(CancellationToken cancellationToken = default) private async Task GenerateImage(CancellationToken cancellationToken = default)
{ {
try try

14
StabilityMatrix.Avalonia/Views/InferenceTextToImageView.axaml

@ -201,13 +201,21 @@
<Grid <Grid
x:CompileBindings="False" x:CompileBindings="False"
DataContext="{Binding ElementName=Dock, Path=DataContext}"> DataContext="{Binding ElementName=Dock, Path=DataContext}">
<ProgressBar <StackPanel
Margin="2,1,2,4"
VerticalAlignment="Top"
DataContext="{Binding OutputProgress}" DataContext="{Binding OutputProgress}"
Margin="2,1,2,4"
Spacing="4"
VerticalAlignment="Top">
<ProgressBar
IsVisible="{Binding IsProgressVisible}" IsVisible="{Binding IsProgressVisible}"
Maximum="{Binding Maximum}" Maximum="{Binding Maximum}"
Value="{Binding Value}" /> Value="{Binding Value}" />
<TextBlock
IsVisible="{Binding IsTextVisible}"
TextAlignment="Center"
Text="{Binding Text}"/>
</StackPanel>
<controls:ImageGalleryCard <controls:ImageGalleryCard
Grid.Row="0" Grid.Row="0"
DataContext="{Binding ImageGalleryCardViewModel}" /> DataContext="{Binding ImageGalleryCardViewModel}" />

45
StabilityMatrix.Core/Inference/ComfyClient.cs

@ -31,7 +31,12 @@ public class ComfyClient : InferenceClientBase
/// <summary> /// <summary>
/// Dictionary of ongoing prompt execution tasks /// Dictionary of ongoing prompt execution tasks
/// </summary> /// </summary>
public ConcurrentDictionary<string, TaskCompletionSource> PromptTasks { get; } = new(); public ConcurrentDictionary<string, ComfyTask> PromptTasks { get; } = new();
/// <summary>
/// Current running prompt task
/// </summary>
private ComfyTask? currentPromptTask;
/// <summary> /// <summary>
/// Event raised when a progress update is received from the server /// Event raised when a progress update is received from the server
@ -138,13 +143,23 @@ public class ComfyClient : InferenceClientBase
{ {
if (PromptTasks.TryRemove(executingData.PromptId, out var task)) if (PromptTasks.TryRemove(executingData.PromptId, out var task))
{ {
task.RunningNode = null;
task.SetResult(); task.SetResult();
currentPromptTask = null;
} }
else else
{ {
Logger.Warn($"Could not find task for prompt {executingData.PromptId}, skipping"); Logger.Warn($"Could not find task for prompt {executingData.PromptId}, skipping");
} }
} }
// Otherwise set the task's active node to the one received
else
{
if (PromptTasks.TryGetValue(executingData.PromptId, out var task))
{
task.RunningNode = executingData.Node;
}
}
ExecutingUpdateReceived?.Invoke(this, executingData); ExecutingUpdateReceived?.Invoke(this, executingData);
} }
@ -168,6 +183,9 @@ public class ComfyClient : InferenceClientBase
return; return;
} }
// Set for the current prompt task
currentPromptTask?.OnProgressUpdate(progressData);
ProgressUpdateReceived?.Invoke(this, progressData); ProgressUpdateReceived?.Invoke(this, progressData);
} }
else else
@ -221,7 +239,7 @@ public class ComfyClient : InferenceClientBase
.ConfigureAwait(false); .ConfigureAwait(false);
} }
public async Task<(ComfyPromptResponse, Task)> QueuePromptAsync( public async Task<ComfyTask> QueuePromptAsync(
Dictionary<string, ComfyNode> nodes, Dictionary<string, ComfyNode> nodes,
CancellationToken cancellationToken = default CancellationToken cancellationToken = default
) )
@ -229,11 +247,26 @@ public class ComfyClient : InferenceClientBase
var request = new ComfyPromptRequest { ClientId = ClientId, Prompt = nodes }; var request = new ComfyPromptRequest { ClientId = ClientId, Prompt = nodes };
var result = await comfyApi.PostPrompt(request, cancellationToken).ConfigureAwait(false); var result = await comfyApi.PostPrompt(request, cancellationToken).ConfigureAwait(false);
// Add task to dictionary // Add task to dictionary and set it as the current task
var tcs = new TaskCompletionSource(); var task = new ComfyTask(result.PromptId);
PromptTasks[result.PromptId] = tcs; PromptTasks[result.PromptId] = task;
currentPromptTask = task;
return task;
}
return (result, tcs.Task); public async Task InterruptPromptAsync(CancellationToken cancellationToken = default)
{
await comfyApi.PostInterrupt(cancellationToken).ConfigureAwait(false);
// Set the current task to null, and remove it from the dictionary
if (currentPromptTask is { } task)
{
PromptTasks.TryRemove(task.Id, out _);
task.TrySetCanceled(cancellationToken);
task.Dispose();
currentPromptTask = null;
}
} }
public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync( public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync(

7
StabilityMatrix.Core/Inference/ComfyProgressUpdateEventArgs.cs

@ -0,0 +1,7 @@
namespace StabilityMatrix.Core.Inference;
public readonly record struct ComfyProgressUpdateEventArgs(
int Value,
int Maximum,
string? TaskId,
string? RunningNode);

33
StabilityMatrix.Core/Inference/ComfyTask.cs

@ -0,0 +1,33 @@
using System.Reactive;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Core.Inference;
public class ComfyTask : TaskCompletionSource, IDisposable
{
public string Id { get; set; }
public string? RunningNode { get; set; }
public EventHandler<ComfyProgressUpdateEventArgs>? ProgressUpdate;
public ComfyTask(string id)
{
Id = id;
}
/// <summary>
/// Handler for progress updates
/// </summary>
public void OnProgressUpdate(ComfyWebSocketProgressData update)
{
ProgressUpdate?.Invoke(this, new ComfyProgressUpdateEventArgs(update.Value, update.Max, Id, RunningNode));
}
/// <inheritdoc />
public void Dispose()
{
ProgressUpdate = null;
GC.SuppressFinalize(this);
}
}
Loading…
Cancel
Save