Browse Source

ComfyApi update

pull/165/head
Ionite 1 year ago
parent
commit
2e9d189239
No known key found for this signature in database
  1. 21
      StabilityMatrix.Core/Api/IComfyApi.cs
  2. 163
      StabilityMatrix.Core/Inference/ComfyClient.cs
  3. 103
      StabilityMatrix.Core/Inference/ComfyWebSocketClient.cs
  4. 6
      StabilityMatrix.Core/Inference/IInferenceClient.cs
  5. 30
      StabilityMatrix.Core/Inference/InferenceClientBase.cs
  6. 9
      StabilityMatrix.Core/Models/Api/Comfy/ComfyHistoryOutput.cs
  7. 9
      StabilityMatrix.Core/Models/Api/Comfy/ComfyHistoryResponse.cs
  8. 15
      StabilityMatrix.Core/Models/Api/Comfy/ComfyImage.cs
  9. 12
      StabilityMatrix.Core/Models/Api/Comfy/ComfyNode.cs
  10. 12
      StabilityMatrix.Core/Models/Api/Comfy/ComfyPromptRequest.cs
  11. 16
      StabilityMatrix.Core/Models/Api/Comfy/ComfyPromptResponse.cs
  12. 29
      StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs
  13. 29
      StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs
  14. 12
      StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseUnion.cs
  15. 9
      StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyStatus.cs
  16. 9
      StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyStatusExecInfo.cs
  17. 15
      StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutingData.cs
  18. 12
      StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketProgressData.cs
  19. 9
      StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketStatusData.cs

21
StabilityMatrix.Core/Api/IComfyApi.cs

@ -1,9 +1,28 @@
using Refit;
using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Core.Api;
[Headers("User-Agent: StabilityMatrix")]
public interface IComfyApi
{
[Post("/prompt")]
Task<ComfyPromptResponse> PostPrompt(
[Body] ComfyPromptRequest prompt,
CancellationToken cancellationToken = default
);
[Get("/history/{promptId}")]
Task<ComfyHistoryResponse> GetHistory(
string promptId,
CancellationToken cancellationToken = default
);
[Get("/view")]
Task<Stream> DownloadImage(
string filename,
string subfolder,
string type,
CancellationToken cancellationToken = default
);
}

163
StabilityMatrix.Core/Inference/ComfyClient.cs

@ -1,76 +1,135 @@
using System.Buffers;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using StabilityMatrix.Core.Helper;
using System.Net.WebSockets;
using NLog;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.Progress;
namespace StabilityMatrix.Core.Inference;
/// <summary>
/// Websocket client for Comfy inference server
/// Connects to localhost:8188 by default
/// </summary>
public class ComfyClient : IInferenceClient
public class ComfyClient : InferenceClientBase
{
private readonly ClientWebSocket clientWebSocket = new();
private readonly CancellationTokenSource cancellationTokenSource = new();
private readonly CancellationToken cancellationToken;
private readonly JsonSerializerOptions? jsonSerializerOptions = new()
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
WriteIndented = true
};
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
protected Guid ClientId { get; } = Guid.NewGuid();
private readonly ComfyWebSocketClient webSocketClient = new();
private readonly IComfyApi comfyApi;
private readonly Uri baseAddress;
private bool isDisposed;
public ComfyClient()
// ReSharper disable once MemberCanBePrivate.Global
public string ClientId { get; private set; } = Guid.NewGuid().ToString();
public ComfyClient(IApiFactory apiFactory, Uri baseAddress)
{
cancellationToken = cancellationTokenSource.Token;
comfyApi = apiFactory.CreateRefitClient<IComfyApi>(baseAddress);
this.baseAddress = baseAddress;
}
public async Task ConnectAsync(Uri uri)
public override async Task ConnectAsync(CancellationToken cancellationToken = default)
{
await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false);
await webSocketClient.ConnectAsync(baseAddress, ClientId).ConfigureAwait(false);
}
public async Task SendAsync<T>(T message)
public override async Task CloseAsync(CancellationToken cancellationToken = default)
{
var json = JsonSerializer.Serialize(message, jsonSerializerOptions);
var bytes = Encoding.UTF8.GetBytes(json);
var buffer = new ArraySegment<byte>(bytes);
await clientWebSocket
.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken)
.ConfigureAwait(false);
await webSocketClient.CloseAsync().ConfigureAwait(false);
}
public async Task<T?> ReceiveAsync<T>()
public Task<ComfyPromptResponse> QueuePromptAsync(
Dictionary<string, ComfyNode> nodes,
CancellationToken cancellationToken = default)
{
var shared = ArrayPool<byte>.Shared;
var buffer = shared.Rent(1024);
try
var request = new ComfyPromptRequest
{
var result = await clientWebSocket
.ReceiveAsync(buffer, cancellationToken)
.ConfigureAwait(false);
var json = Encoding.UTF8.GetString(buffer, 0, result.Count);
return JsonSerializer.Deserialize<T>(json, jsonSerializerOptions);
ClientId = ClientId,
Prompt = nodes,
};
return comfyApi.PostPrompt(request, cancellationToken);
}
public async Task<Dictionary<string, List<ComfyImage>?>> ExecutePromptAsync(
Dictionary<string, ComfyNode> nodes,
IProgress<ProgressReport>? progress = default,
CancellationToken cancellationToken = default)
{
var response = await QueuePromptAsync(nodes, cancellationToken).ConfigureAwait(false);
var promptId = response.PromptId;
while (true)
{
var message = await webSocketClient.ReceiveAsync().ConfigureAwait(false);
if (message is null)
{
Logger.Warn("Received null message");
break;
}
// Stop if closed
if (message.MessageType == WebSocketMessageType.Close)
{
Logger.Trace("Received close message");
break;
}
if (message.Json is { } json)
{
Logger.Trace("Received json message: (Type = {Type}, Data = {Data})",
json.Type, json.Data);
// Stop if we get an executing response with null Node property
if (json.Type is ComfyWebSocketResponseType.Executing)
{
var executingData = json.GetDataAsType<ComfyWebSocketExecutingData>();
// We need this to stop the loop, so if it's null, we'll throw
if (executingData is null)
{
throw new NullReferenceException("Could not parse executing data");
}
// Check this is for us
if (executingData.PromptId != promptId)
{
Logger.Trace("Received executing message for different prompt - ignoring");
continue;
}
if (executingData.Node is null)
{
Logger.Trace("Received executing message with null node - stopping");
break;
}
}
else if (json.Type is ComfyWebSocketResponseType.Progress)
{
var progressData = json.GetDataAsType<ComfyWebSocketProgressData>();
if (progressData is null)
{
Logger.Warn("Could not parse progress data");
continue;
}
progress?.Report(new ProgressReport
{
Current = Convert.ToUInt64(progressData.Value),
Total = Convert.ToUInt64(progressData.Max),
});
}
}
}
finally
// Get history for images
var history = await comfyApi.GetHistory(promptId, cancellationToken).ConfigureAwait(false);
var dict = new Dictionary<string, List<ComfyImage>?>();
foreach (var (nodeKey, output) in history.Outputs)
{
shared.Return(buffer);
dict[nodeKey] = output.Images;
}
return dict;
}
public async Task CloseAsync()
{
await clientWebSocket
.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", cancellationToken)
.ConfigureAwait(false);
}
public void Dispose()
protected override void Dispose(bool disposing)
{
clientWebSocket.Dispose();
cancellationTokenSource.Dispose();
if (isDisposed) return;
webSocketClient.Dispose();
isDisposed = true;
}
}

103
StabilityMatrix.Core/Inference/ComfyWebSocketClient.cs

@ -0,0 +1,103 @@
using System.Buffers;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Core.Inference;
/// <summary>
/// Websocket client for Comfy inference server
/// Connects to localhost:8188 by default
/// </summary>
public class ComfyWebSocketClient : IDisposable
{
private bool isDisposed;
private readonly ClientWebSocket clientWebSocket = new();
private readonly CancellationTokenSource cancellationTokenSource = new();
private readonly CancellationToken cancellationToken;
private readonly JsonSerializerOptions? jsonSerializerOptions = new()
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
WriteIndented = true
};
public ComfyWebSocketClient()
{
cancellationToken = cancellationTokenSource.Token;
}
public async Task ConnectAsync(Uri baseAddress, string clientId)
{
var uri = new UriBuilder(baseAddress)
{
Path = "/ws",
Query = $"client_id={clientId}"
}.Uri;
await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false);
}
public async Task SendAsync<T>(T message)
{
var json = JsonSerializer.Serialize(message, jsonSerializerOptions);
var bytes = Encoding.UTF8.GetBytes(json);
var buffer = new ArraySegment<byte>(bytes);
await clientWebSocket
.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken)
.ConfigureAwait(false);
}
public async Task<ComfyWebSocketResponseUnion?> ReceiveAsync()
{
var shared = ArrayPool<byte>.Shared;
var buffer = shared.Rent(1024);
try
{
var result = await clientWebSocket
.ReceiveAsync(buffer, cancellationToken)
.ConfigureAwait(false);
if (result.MessageType is WebSocketMessageType.Binary)
{
return new ComfyWebSocketResponseUnion
{
MessageType = result.MessageType,
Json = null,
Bytes = buffer.AsSpan(0, result.Count).ToArray()
};
}
else
{
var text = Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count));
var json = JsonSerializer.Deserialize<ComfyWebSocketResponse>(text, jsonSerializerOptions);
return new ComfyWebSocketResponseUnion
{
MessageType = result.MessageType,
Json = json,
Bytes = null
};
}
}
finally
{
shared.Return(buffer);
}
}
public async Task CloseAsync()
{
await clientWebSocket
.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", cancellationToken)
.ConfigureAwait(false);
}
public void Dispose()
{
if (isDisposed) return;
clientWebSocket.Dispose();
cancellationTokenSource.Dispose();
isDisposed = true;
GC.SuppressFinalize(this);
}
}

6
StabilityMatrix.Core/Inference/IInferenceClient.cs

@ -1,6 +0,0 @@
namespace StabilityMatrix.Core.Inference;
public interface IInferenceClient
{
}

30
StabilityMatrix.Core/Inference/InferenceClientBase.cs

@ -0,0 +1,30 @@
namespace StabilityMatrix.Core.Inference;
public abstract class InferenceClientBase : IDisposable
{
/// <summary>
/// Start the connection
/// </summary>
public virtual Task ConnectAsync(CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
/// <summary>
/// Close the connection to remote resources
/// </summary>
public virtual Task CloseAsync(CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
protected virtual void Dispose(bool disposing)
{
}
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
}

9
StabilityMatrix.Core/Models/Api/Comfy/ComfyHistoryOutput.cs

@ -0,0 +1,9 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public class ComfyHistoryOutput
{
[JsonPropertyName("images")]
public List<ComfyImage>? Images { get; set; }
}

9
StabilityMatrix.Core/Models/Api/Comfy/ComfyHistoryResponse.cs

@ -0,0 +1,9 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public class ComfyHistoryResponse
{
[JsonPropertyName("outputs")]
public required Dictionary<string, ComfyHistoryOutput> Outputs { get; set; }
}

15
StabilityMatrix.Core/Models/Api/Comfy/ComfyImage.cs

@ -0,0 +1,15 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public class ComfyImage
{
[JsonPropertyName("filename")]
public required string FileName { get; set; }
[JsonPropertyName("subfolder")]
public required string SubFolder { get; set; }
[JsonPropertyName("type")]
public required string Type { get; set; }
}

12
StabilityMatrix.Core/Models/Api/Comfy/ComfyNode.cs

@ -0,0 +1,12 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public class ComfyNode
{
[JsonPropertyName("class_type")]
public required string ClassType;
[JsonPropertyName("inputs")]
public required Dictionary<string, object?> Inputs;
}

12
StabilityMatrix.Core/Models/Api/Comfy/ComfyPromptRequest.cs

@ -0,0 +1,12 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public class ComfyPromptRequest
{
[JsonPropertyName("client_id")]
public required string ClientId;
[JsonPropertyName("prompt")]
public required Dictionary<string, ComfyNode> Prompt;
}

16
StabilityMatrix.Core/Models/Api/Comfy/ComfyPromptResponse.cs

@ -0,0 +1,16 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy;
// ReSharper disable once ClassNeverInstantiated.Global
public class ComfyPromptResponse
{
[JsonPropertyName("prompt_id")]
public required string PromptId { get; set; }
[JsonPropertyName("number")]
public required int Number { get; set; }
[JsonPropertyName("node_errors")]
public required Dictionary<string, object?> NodeErrors { get; set; }
}

29
StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponse.cs

@ -0,0 +1,29 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public class ComfyWebSocketResponse
{
[JsonPropertyName("type")]
public required ComfyWebSocketResponseType Type { get; set; }
/// <summary>
/// Depending on the value of <see cref="Type"/>,
/// this property will be one of these types
/// <list type="bullet">
/// <item>Status - <see cref="ComfyWebSocketStatusData"/></item>
/// <item>Progress - <see cref="ComfyWebSocketProgressData"/></item>
/// <item>Executing - <see cref="ComfyWebSocketExecutingData"/></item>
/// </list>
/// </summary>
[JsonPropertyName("data")]
public required JsonObject Data { get; set; }
public T? GetDataAsType<T>() where T : class
{
return Data.Deserialize<T>();
}
}

29
StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseType.cs

@ -0,0 +1,29 @@
using System.Runtime.Serialization;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Converters.Json;
namespace StabilityMatrix.Core.Models.Api.Comfy;
[JsonConverter(typeof(DefaultUnknownEnumConverter<ComfyWebSocketResponseType>))]
public enum ComfyWebSocketResponseType
{
Unknown,
[EnumMember(Value = "status")]
Status,
[EnumMember(Value = "execution_start")]
ExecutionStart,
[EnumMember(Value = "execution_cached")]
ExecutionCached,
[EnumMember(Value = "executing")]
Executing,
[EnumMember(Value = "progress")]
Progress,
[EnumMember(Value = "executed")]
Executed,
}

12
StabilityMatrix.Core/Models/Api/Comfy/ComfyWebSocketResponseUnion.cs

@ -0,0 +1,12 @@
using System.Net.WebSockets;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public record ComfyWebSocketResponseUnion
{
public WebSocketMessageType MessageType { get; set; }
public ComfyWebSocketResponse? Json { get; set; }
public byte[]? Bytes { get; set; }
};

9
StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyStatus.cs

@ -0,0 +1,9 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
public record struct ComfyStatus
{
[JsonPropertyName("exec_info")]
public required int ExecInfo { get; set; }
}

9
StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyStatusExecInfo.cs

@ -0,0 +1,9 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
public record ComfyStatusExecInfo
{
[JsonPropertyName("queue_remaining")]
public required int QueueRemaining { get; set; }
}

15
StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketExecutingData.cs

@ -0,0 +1,15 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
public class ComfyWebSocketExecutingData
{
[JsonPropertyName("prompt_id")]
public required string PromptId { get; set; }
/// <summary>
/// When this is null it indicates completed
/// </summary>
[JsonPropertyName("node")]
public required int? Node { get; set; }
}

12
StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketProgressData.cs

@ -0,0 +1,12 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
public record ComfyWebSocketProgressData
{
[JsonPropertyName("value")]
public required int Value { get; set; }
[JsonPropertyName("max")]
public required int Max { get; set; }
}

9
StabilityMatrix.Core/Models/Api/Comfy/WebSocketData/ComfyWebSocketStatusData.cs

@ -0,0 +1,9 @@
using System.Text.Json.Serialization;
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
public record ComfyWebSocketStatusData
{
[JsonPropertyName("status")]
public required ComfyStatus Status { get; set; }
}
Loading…
Cancel
Save