diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index f1d0d11c..9eb6a5c9 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -257,8 +257,23 @@ public abstract partial class InferenceGenerationViewModelBase .SafeFireAndForget(); // Wait for prompt to finish - await promptTask.Task.WaitAsync(cancellationToken); - Logger.Debug($"Prompt task {promptTask.Id} finished"); + try + { + await promptTask.Task.WaitAsync(cancellationToken); + Logger.Debug($"Prompt task {promptTask.Id} finished"); + } + catch (ComfyNodeException e) + { + Logger.Warn(e, "Comfy node exception while queuing prompt"); + await DialogHelper + .CreateJsonDialog( + e.JsonData, + "Comfy Error", + "Node execution encountered an error" + ) + .ShowAsync(); + return; + } // Get output images var imageOutputs = await client.GetImagesForExecutedPromptAsync( diff --git a/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs b/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs index 62faa835..dc7e1a8a 100644 --- a/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs +++ b/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs @@ -1,14 +1,44 @@ -using System.Text.Json; +using System.Reflection; +using System.Runtime.Serialization; +using System.Text.Json; using System.Text.Json.Serialization; using StabilityMatrix.Core.Extensions; namespace StabilityMatrix.Core.Converters.Json; -public class DefaultUnknownEnumConverter : JsonConverter where T : Enum +public class DefaultUnknownEnumConverter : JsonConverter + where T : Enum { - public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + // Get EnumMember attribute value + private Dictionary? _enumMemberValues; + + private IReadOnlyDictionary EnumMemberValues => + _enumMemberValues ??= typeof(T) + .GetFields() + .Where(field => field.IsStatic) + .Select( + field => + new + { + Field = field, + Attribute = field + .GetCustomAttributes(false) + .FirstOrDefault() + } + ) + .Where(field => field.Attribute != null) + .ToDictionary( + field => field.Attribute!.Value!.ToString(), + field => (T)field.Field.GetValue(null)! + ); + + public override T Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options + ) { - if (reader.TokenType != JsonTokenType.String) + if (reader.TokenType != JsonTokenType.String) { throw new JsonException(); } @@ -16,15 +46,24 @@ public class DefaultUnknownEnumConverter : JsonConverter where T : Enum var enumText = reader.GetString()?.Replace(" ", "_"); if (Enum.TryParse(typeof(T), enumText, true, out var result)) { - return (T) result!; + return (T)result!; + } + + // Try using enum member values + if (enumText != null) + { + if (EnumMemberValues.TryGetValue(enumText, out var enumMemberResult)) + { + return enumMemberResult; + } } // Unknown value handling - if (Enum.TryParse(typeof(T), "Unknown", true, out var unknownResult)) + if (Enum.TryParse(typeof(T), "Unknown", true, out var unknownResult)) { - return (T) unknownResult!; + return (T)unknownResult!; } - + throw new JsonException($"Unable to parse '{enumText}' to enum '{typeof(T)}'."); } diff --git a/StabilityMatrix.Core/Exceptions/ComfyNodeException.cs b/StabilityMatrix.Core/Exceptions/ComfyNodeException.cs new file mode 100644 index 00000000..78d827ae --- /dev/null +++ b/StabilityMatrix.Core/Exceptions/ComfyNodeException.cs @@ -0,0 +1,9 @@ +using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; + +namespace StabilityMatrix.Core.Exceptions; + +public class ComfyNodeException : Exception +{ + public required ComfyWebSocketExecutionErrorData ErrorData { get; init; } + public required string JsonData { get; init; } +} diff --git a/StabilityMatrix.Core/Inference/ComfyClient.cs b/StabilityMatrix.Core/Inference/ComfyClient.cs index 1b7e09e6..e775c65d 100644 --- a/StabilityMatrix.Core/Inference/ComfyClient.cs +++ b/StabilityMatrix.Core/Inference/ComfyClient.cs @@ -6,6 +6,7 @@ using NLog; using Polly.Contrib.WaitAndRetry; using Refit; using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; @@ -13,6 +14,7 @@ using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.FileInterfaces; using Websocket.Client; using Websocket.Client.Exceptions; +using Yoh.Text.Json.NamingPolicies; namespace StabilityMatrix.Core.Inference; @@ -24,6 +26,9 @@ public class ComfyClient : InferenceClientBase private readonly IComfyApi comfyApi; private bool isDisposed; + private JsonSerializerOptions jsonSerializerOptions = + new() { PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower, }; + // ReSharper disable once MemberCanBePrivate.Global public string ClientId { get; } = Guid.NewGuid().ToString(); @@ -121,7 +126,7 @@ public class ComfyClient : InferenceClientBase ComfyWebSocketResponse? json; try { - json = JsonSerializer.Deserialize(text); + json = JsonSerializer.Deserialize(text, jsonSerializerOptions); } catch (JsonException e) { @@ -139,7 +144,9 @@ public class ComfyClient : InferenceClientBase if (json.Type == ComfyWebSocketResponseType.Executing) { - var executingData = json.GetDataAsType(); + var executingData = json.GetDataAsType( + jsonSerializerOptions + ); if (executingData?.PromptId is null) { Logger.Warn($"Could not parse executing data {json.Data}, skipping"); @@ -176,7 +183,7 @@ public class ComfyClient : InferenceClientBase } else if (json.Type == ComfyWebSocketResponseType.Status) { - var statusData = json.GetDataAsType(); + var statusData = json.GetDataAsType(jsonSerializerOptions); if (statusData is null) { Logger.Warn($"Could not parse status data {json.Data}, skipping"); @@ -187,7 +194,9 @@ public class ComfyClient : InferenceClientBase } else if (json.Type == ComfyWebSocketResponseType.Progress) { - var progressData = json.GetDataAsType(); + var progressData = json.GetDataAsType( + jsonSerializerOptions + ); if (progressData is null) { Logger.Warn($"Could not parse progress data {json.Data}, skipping"); @@ -199,6 +208,35 @@ public class ComfyClient : InferenceClientBase ProgressUpdateReceived?.Invoke(this, progressData); } + else if (json.Type == ComfyWebSocketResponseType.ExecutionError) + { + if ( + json.GetDataAsType(jsonSerializerOptions) + is not { } errorData + ) + { + Logger.Warn($"Could not parse ExecutionError data {json.Data}, skipping"); + return; + } + + // Set error status + if (PromptTasks.TryRemove(errorData.PromptId, out var task)) + { + task.RunningNode = null; + task.SetException( + new ComfyNodeException + { + ErrorData = errorData, + JsonData = json.Data.ToString() + } + ); + currentPromptTask = null; + } + else + { + Logger.Warn($"Could not find task for prompt {errorData.PromptId}, skipping"); + } + } else { Logger.Warn($"Unknown message type {json.Type} ({json.Data}), skipping"); @@ -311,7 +349,13 @@ public class ComfyClient : InferenceClientBase ) { var streamPart = new StreamPart(image, fileName); - return comfyApi.PostUploadImage(streamPart, true, "input", "Inference", cancellationToken); + return comfyApi.PostUploadImage( + streamPart, + "true", + "input", + "Inference", + cancellationToken + ); } public async Task?>> GetImagesForExecutedPromptAsync( diff --git a/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs b/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs index 0a2ed7c9..efd211c1 100644 --- a/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs +++ b/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs @@ -9,7 +9,7 @@ public class ComfyWebSocketResponse { [JsonPropertyName("type")] public required ComfyWebSocketResponseType Type { get; set; } - + /// /// Depending on the value of , /// this property will be one of these types @@ -21,9 +21,10 @@ public class ComfyWebSocketResponse /// [JsonPropertyName("data")] public required JsonObject Data { get; set; } - - public T? GetDataAsType() where T : class + + public T? GetDataAsType(JsonSerializerOptions? options = null) + where T : class { - return Data.Deserialize(); + return Data.Deserialize(options); } } diff --git a/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs b/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs index 17456fbe..7cbe40b8 100644 --- a/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs +++ b/StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs @@ -8,22 +8,25 @@ namespace StabilityMatrix.Core.Models.Api.Comfy; public enum ComfyWebSocketResponseType { Unknown, - + [EnumMember(Value = "status")] Status, - + [EnumMember(Value = "execution_start")] ExecutionStart, - + [EnumMember(Value = "execution_cached")] ExecutionCached, - + + [EnumMember(Value = "execution_error")] + ExecutionError, + [EnumMember(Value = "executing")] Executing, - + [EnumMember(Value = "progress")] Progress, - + [EnumMember(Value = "executed")] Executed, } diff --git a/StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutionErrorData.cs b/StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutionErrorData.cs new file mode 100644 index 00000000..1a6842ca --- /dev/null +++ b/StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutionErrorData.cs @@ -0,0 +1,11 @@ +namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; + +public record ComfyWebSocketExecutionErrorData +{ + public required string PromptId { get; set; } + public string? NodeId { get; set; } + public string? NodeType { get; set; } + public string? ExceptionMessage { get; set; } + public string? ExceptionType { get; set; } + public string[]? Traceback { get; set; } +}