Browse Source

Report node exception websocket events in dialog

pull/333/head
Ionite 1 year ago
parent
commit
efb7c704df
No known key found for this signature in database
  1. 19
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  2. 55
      StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs
  3. 9
      StabilityMatrix.Core/Exceptions/ComfyNodeException.cs
  4. 54
      StabilityMatrix.Core/Inference/ComfyClient.cs
  5. 9
      StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs
  6. 15
      StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs
  7. 11
      StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutionErrorData.cs

19
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -257,8 +257,23 @@ public abstract partial class InferenceGenerationViewModelBase
.SafeFireAndForget(); .SafeFireAndForget();
// Wait for prompt to finish // Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken); try
Logger.Debug($"Prompt task {promptTask.Id} finished"); {
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Debug($"Prompt task {promptTask.Id} finished");
}
catch (ComfyNodeException e)
{
Logger.Warn(e, "Comfy node exception while queuing prompt");
await DialogHelper
.CreateJsonDialog(
e.JsonData,
"Comfy Error",
"Node execution encountered an error"
)
.ShowAsync();
return;
}
// Get output images // Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync( var imageOutputs = await client.GetImagesForExecutedPromptAsync(

55
StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs

@ -1,14 +1,44 @@
using System.Text.Json; using System.Reflection;
using System.Runtime.Serialization;
using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
namespace StabilityMatrix.Core.Converters.Json; namespace StabilityMatrix.Core.Converters.Json;
public class DefaultUnknownEnumConverter<T> : JsonConverter<T> where T : Enum public class DefaultUnknownEnumConverter<T> : JsonConverter<T>
where T : Enum
{ {
public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) // Get EnumMember attribute value
private Dictionary<string, T>? _enumMemberValues;
private IReadOnlyDictionary<string, T> EnumMemberValues =>
_enumMemberValues ??= typeof(T)
.GetFields()
.Where(field => field.IsStatic)
.Select(
field =>
new
{
Field = field,
Attribute = field
.GetCustomAttributes<EnumMemberAttribute>(false)
.FirstOrDefault()
}
)
.Where(field => field.Attribute != null)
.ToDictionary(
field => field.Attribute!.Value!.ToString(),
field => (T)field.Field.GetValue(null)!
);
public override T Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options
)
{ {
if (reader.TokenType != JsonTokenType.String) if (reader.TokenType != JsonTokenType.String)
{ {
throw new JsonException(); throw new JsonException();
} }
@ -16,15 +46,24 @@ public class DefaultUnknownEnumConverter<T> : JsonConverter<T> where T : Enum
var enumText = reader.GetString()?.Replace(" ", "_"); var enumText = reader.GetString()?.Replace(" ", "_");
if (Enum.TryParse(typeof(T), enumText, true, out var result)) if (Enum.TryParse(typeof(T), enumText, true, out var result))
{ {
return (T) result!; return (T)result!;
}
// Try using enum member values
if (enumText != null)
{
if (EnumMemberValues.TryGetValue(enumText, out var enumMemberResult))
{
return enumMemberResult;
}
} }
// Unknown value handling // Unknown value handling
if (Enum.TryParse(typeof(T), "Unknown", true, out var unknownResult)) if (Enum.TryParse(typeof(T), "Unknown", true, out var unknownResult))
{ {
return (T) unknownResult!; return (T)unknownResult!;
} }
throw new JsonException($"Unable to parse '{enumText}' to enum '{typeof(T)}'."); throw new JsonException($"Unable to parse '{enumText}' to enum '{typeof(T)}'.");
} }

9
StabilityMatrix.Core/Exceptions/ComfyNodeException.cs

@ -0,0 +1,9 @@
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Core.Exceptions;
public class ComfyNodeException : Exception
{
public required ComfyWebSocketExecutionErrorData ErrorData { get; init; }
public required string JsonData { get; init; }
}

54
StabilityMatrix.Core/Inference/ComfyClient.cs

@ -6,6 +6,7 @@ using NLog;
using Polly.Contrib.WaitAndRetry; using Polly.Contrib.WaitAndRetry;
using Refit; using Refit;
using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
@ -13,6 +14,7 @@ using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.FileInterfaces;
using Websocket.Client; using Websocket.Client;
using Websocket.Client.Exceptions; using Websocket.Client.Exceptions;
using Yoh.Text.Json.NamingPolicies;
namespace StabilityMatrix.Core.Inference; namespace StabilityMatrix.Core.Inference;
@ -24,6 +26,9 @@ public class ComfyClient : InferenceClientBase
private readonly IComfyApi comfyApi; private readonly IComfyApi comfyApi;
private bool isDisposed; private bool isDisposed;
private JsonSerializerOptions jsonSerializerOptions =
new() { PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower, };
// ReSharper disable once MemberCanBePrivate.Global // ReSharper disable once MemberCanBePrivate.Global
public string ClientId { get; } = Guid.NewGuid().ToString(); public string ClientId { get; } = Guid.NewGuid().ToString();
@ -121,7 +126,7 @@ public class ComfyClient : InferenceClientBase
ComfyWebSocketResponse? json; ComfyWebSocketResponse? json;
try try
{ {
json = JsonSerializer.Deserialize<ComfyWebSocketResponse>(text); json = JsonSerializer.Deserialize<ComfyWebSocketResponse>(text, jsonSerializerOptions);
} }
catch (JsonException e) catch (JsonException e)
{ {
@ -139,7 +144,9 @@ public class ComfyClient : InferenceClientBase
if (json.Type == ComfyWebSocketResponseType.Executing) if (json.Type == ComfyWebSocketResponseType.Executing)
{ {
var executingData = json.GetDataAsType<ComfyWebSocketExecutingData>(); var executingData = json.GetDataAsType<ComfyWebSocketExecutingData>(
jsonSerializerOptions
);
if (executingData?.PromptId is null) if (executingData?.PromptId is null)
{ {
Logger.Warn($"Could not parse executing data {json.Data}, skipping"); Logger.Warn($"Could not parse executing data {json.Data}, skipping");
@ -176,7 +183,7 @@ public class ComfyClient : InferenceClientBase
} }
else if (json.Type == ComfyWebSocketResponseType.Status) else if (json.Type == ComfyWebSocketResponseType.Status)
{ {
var statusData = json.GetDataAsType<ComfyWebSocketStatusData>(); var statusData = json.GetDataAsType<ComfyWebSocketStatusData>(jsonSerializerOptions);
if (statusData is null) if (statusData is null)
{ {
Logger.Warn($"Could not parse status data {json.Data}, skipping"); Logger.Warn($"Could not parse status data {json.Data}, skipping");
@ -187,7 +194,9 @@ public class ComfyClient : InferenceClientBase
} }
else if (json.Type == ComfyWebSocketResponseType.Progress) else if (json.Type == ComfyWebSocketResponseType.Progress)
{ {
var progressData = json.GetDataAsType<ComfyWebSocketProgressData>(); var progressData = json.GetDataAsType<ComfyWebSocketProgressData>(
jsonSerializerOptions
);
if (progressData is null) if (progressData is null)
{ {
Logger.Warn($"Could not parse progress data {json.Data}, skipping"); Logger.Warn($"Could not parse progress data {json.Data}, skipping");
@ -199,6 +208,35 @@ public class ComfyClient : InferenceClientBase
ProgressUpdateReceived?.Invoke(this, progressData); ProgressUpdateReceived?.Invoke(this, progressData);
} }
else if (json.Type == ComfyWebSocketResponseType.ExecutionError)
{
if (
json.GetDataAsType<ComfyWebSocketExecutionErrorData>(jsonSerializerOptions)
is not { } errorData
)
{
Logger.Warn($"Could not parse ExecutionError data {json.Data}, skipping");
return;
}
// Set error status
if (PromptTasks.TryRemove(errorData.PromptId, out var task))
{
task.RunningNode = null;
task.SetException(
new ComfyNodeException
{
ErrorData = errorData,
JsonData = json.Data.ToString()
}
);
currentPromptTask = null;
}
else
{
Logger.Warn($"Could not find task for prompt {errorData.PromptId}, skipping");
}
}
else else
{ {
Logger.Warn($"Unknown message type {json.Type} ({json.Data}), skipping"); Logger.Warn($"Unknown message type {json.Type} ({json.Data}), skipping");
@ -311,7 +349,13 @@ public class ComfyClient : InferenceClientBase
) )
{ {
var streamPart = new StreamPart(image, fileName); var streamPart = new StreamPart(image, fileName);
return comfyApi.PostUploadImage(streamPart, true, "input", "Inference", cancellationToken); return comfyApi.PostUploadImage(
streamPart,
"true",
"input",
"Inference",
cancellationToken
);
} }
public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync( public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync(

9
StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs

@ -9,7 +9,7 @@ public class ComfyWebSocketResponse
{ {
[JsonPropertyName("type")] [JsonPropertyName("type")]
public required ComfyWebSocketResponseType Type { get; set; } public required ComfyWebSocketResponseType Type { get; set; }
/// <summary> /// <summary>
/// Depending on the value of <see cref="Type"/>, /// Depending on the value of <see cref="Type"/>,
/// this property will be one of these types /// this property will be one of these types
@ -21,9 +21,10 @@ public class ComfyWebSocketResponse
/// </summary> /// </summary>
[JsonPropertyName("data")] [JsonPropertyName("data")]
public required JsonObject Data { get; set; } public required JsonObject Data { get; set; }
public T? GetDataAsType<T>() where T : class public T? GetDataAsType<T>(JsonSerializerOptions? options = null)
where T : class
{ {
return Data.Deserialize<T>(); return Data.Deserialize<T>(options);
} }
} }

15
StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs

@ -8,22 +8,25 @@ namespace StabilityMatrix.Core.Models.Api.Comfy;
public enum ComfyWebSocketResponseType public enum ComfyWebSocketResponseType
{ {
Unknown, Unknown,
[EnumMember(Value = "status")] [EnumMember(Value = "status")]
Status, Status,
[EnumMember(Value = "execution_start")] [EnumMember(Value = "execution_start")]
ExecutionStart, ExecutionStart,
[EnumMember(Value = "execution_cached")] [EnumMember(Value = "execution_cached")]
ExecutionCached, ExecutionCached,
[EnumMember(Value = "execution_error")]
ExecutionError,
[EnumMember(Value = "executing")] [EnumMember(Value = "executing")]
Executing, Executing,
[EnumMember(Value = "progress")] [EnumMember(Value = "progress")]
Progress, Progress,
[EnumMember(Value = "executed")] [EnumMember(Value = "executed")]
Executed, Executed,
} }

11
StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutionErrorData.cs

@ -0,0 +1,11 @@
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
public record ComfyWebSocketExecutionErrorData
{
public required string PromptId { get; set; }
public string? NodeId { get; set; }
public string? NodeType { get; set; }
public string? ExceptionMessage { get; set; }
public string? ExceptionType { get; set; }
public string[]? Traceback { get; set; }
}
Loading…
Cancel
Save