Ionite
1 year ago
19 changed files with 461 additions and 59 deletions
@ -1,9 +1,28 @@
|
||||
using Refit; |
||||
using StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
namespace StabilityMatrix.Core.Api; |
||||
|
||||
[Headers("User-Agent: StabilityMatrix")] |
||||
public interface IComfyApi |
||||
{ |
||||
|
||||
[Post("/prompt")] |
||||
Task<ComfyPromptResponse> PostPrompt( |
||||
[Body] ComfyPromptRequest prompt, |
||||
CancellationToken cancellationToken = default |
||||
); |
||||
|
||||
[Get("/history/{promptId}")] |
||||
Task<ComfyHistoryResponse> GetHistory( |
||||
string promptId, |
||||
CancellationToken cancellationToken = default |
||||
); |
||||
|
||||
[Get("/view")] |
||||
Task<Stream> DownloadImage( |
||||
string filename, |
||||
string subfolder, |
||||
string type, |
||||
CancellationToken cancellationToken = default |
||||
); |
||||
} |
||||
|
@ -1,76 +1,135 @@
|
||||
using System.Buffers; |
||||
using System.Net.WebSockets; |
||||
using System.Text; |
||||
using System.Text.Json; |
||||
using StabilityMatrix.Core.Helper; |
||||
using System.Net.WebSockets; |
||||
using NLog; |
||||
using StabilityMatrix.Core.Api; |
||||
using StabilityMatrix.Core.Models.Api.Comfy; |
||||
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
using StabilityMatrix.Core.Models.Progress; |
||||
|
||||
namespace StabilityMatrix.Core.Inference; |
||||
|
||||
/// <summary> |
||||
/// Websocket client for Comfy inference server |
||||
/// Connects to localhost:8188 by default |
||||
/// </summary> |
||||
public class ComfyClient : IInferenceClient |
||||
public class ComfyClient : InferenceClientBase |
||||
{ |
||||
private readonly ClientWebSocket clientWebSocket = new(); |
||||
private readonly CancellationTokenSource cancellationTokenSource = new(); |
||||
private readonly CancellationToken cancellationToken; |
||||
private readonly JsonSerializerOptions? jsonSerializerOptions = new() |
||||
{ |
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase, |
||||
WriteIndented = true |
||||
}; |
||||
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); |
||||
|
||||
protected Guid ClientId { get; } = Guid.NewGuid(); |
||||
private readonly ComfyWebSocketClient webSocketClient = new(); |
||||
private readonly IComfyApi comfyApi; |
||||
private readonly Uri baseAddress; |
||||
private bool isDisposed; |
||||
|
||||
public ComfyClient() |
||||
// ReSharper disable once MemberCanBePrivate.Global |
||||
public string ClientId { get; private set; } = Guid.NewGuid().ToString(); |
||||
|
||||
public ComfyClient(IApiFactory apiFactory, Uri baseAddress) |
||||
{ |
||||
cancellationToken = cancellationTokenSource.Token; |
||||
comfyApi = apiFactory.CreateRefitClient<IComfyApi>(baseAddress); |
||||
this.baseAddress = baseAddress; |
||||
} |
||||
|
||||
public async Task ConnectAsync(Uri uri) |
||||
public override async Task ConnectAsync(CancellationToken cancellationToken = default) |
||||
{ |
||||
await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); |
||||
await webSocketClient.ConnectAsync(baseAddress, ClientId).ConfigureAwait(false); |
||||
} |
||||
|
||||
public async Task SendAsync<T>(T message) |
||||
public override async Task CloseAsync(CancellationToken cancellationToken = default) |
||||
{ |
||||
var json = JsonSerializer.Serialize(message, jsonSerializerOptions); |
||||
var bytes = Encoding.UTF8.GetBytes(json); |
||||
var buffer = new ArraySegment<byte>(bytes); |
||||
await clientWebSocket |
||||
.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken) |
||||
.ConfigureAwait(false); |
||||
await webSocketClient.CloseAsync().ConfigureAwait(false); |
||||
} |
||||
|
||||
public async Task<T?> ReceiveAsync<T>() |
||||
|
||||
public Task<ComfyPromptResponse> QueuePromptAsync( |
||||
Dictionary<string, ComfyNode> nodes, |
||||
CancellationToken cancellationToken = default) |
||||
{ |
||||
var shared = ArrayPool<byte>.Shared; |
||||
var buffer = shared.Rent(1024); |
||||
try |
||||
var request = new ComfyPromptRequest |
||||
{ |
||||
var result = await clientWebSocket |
||||
.ReceiveAsync(buffer, cancellationToken) |
||||
.ConfigureAwait(false); |
||||
var json = Encoding.UTF8.GetString(buffer, 0, result.Count); |
||||
return JsonSerializer.Deserialize<T>(json, jsonSerializerOptions); |
||||
ClientId = ClientId, |
||||
Prompt = nodes, |
||||
}; |
||||
return comfyApi.PostPrompt(request, cancellationToken); |
||||
} |
||||
|
||||
public async Task<Dictionary<string, List<ComfyImage>?>> ExecutePromptAsync( |
||||
Dictionary<string, ComfyNode> nodes, |
||||
IProgress<ProgressReport>? progress = default, |
||||
CancellationToken cancellationToken = default) |
||||
{ |
||||
var response = await QueuePromptAsync(nodes, cancellationToken).ConfigureAwait(false); |
||||
var promptId = response.PromptId; |
||||
|
||||
while (true) |
||||
{ |
||||
var message = await webSocketClient.ReceiveAsync().ConfigureAwait(false); |
||||
|
||||
if (message is null) |
||||
{ |
||||
Logger.Warn("Received null message"); |
||||
break; |
||||
} |
||||
|
||||
// Stop if closed |
||||
if (message.MessageType == WebSocketMessageType.Close) |
||||
{ |
||||
Logger.Trace("Received close message"); |
||||
break; |
||||
} |
||||
|
||||
if (message.Json is { } json) |
||||
{ |
||||
Logger.Trace("Received json message: (Type = {Type}, Data = {Data})", |
||||
json.Type, json.Data); |
||||
|
||||
// Stop if we get an executing response with null Node property |
||||
if (json.Type is ComfyWebSocketResponseType.Executing) |
||||
{ |
||||
var executingData = json.GetDataAsType<ComfyWebSocketExecutingData>(); |
||||
// We need this to stop the loop, so if it's null, we'll throw |
||||
if (executingData is null) |
||||
{ |
||||
throw new NullReferenceException("Could not parse executing data"); |
||||
} |
||||
// Check this is for us |
||||
if (executingData.PromptId != promptId) |
||||
{ |
||||
Logger.Trace("Received executing message for different prompt - ignoring"); |
||||
continue; |
||||
} |
||||
if (executingData.Node is null) |
||||
{ |
||||
Logger.Trace("Received executing message with null node - stopping"); |
||||
break; |
||||
} |
||||
} |
||||
else if (json.Type is ComfyWebSocketResponseType.Progress) |
||||
{ |
||||
var progressData = json.GetDataAsType<ComfyWebSocketProgressData>(); |
||||
if (progressData is null) |
||||
{ |
||||
Logger.Warn("Could not parse progress data"); |
||||
continue; |
||||
} |
||||
progress?.Report(new ProgressReport |
||||
{ |
||||
Current = Convert.ToUInt64(progressData.Value), |
||||
Total = Convert.ToUInt64(progressData.Max), |
||||
}); |
||||
} |
||||
} |
||||
} |
||||
finally |
||||
|
||||
// Get history for images |
||||
var history = await comfyApi.GetHistory(promptId, cancellationToken).ConfigureAwait(false); |
||||
|
||||
var dict = new Dictionary<string, List<ComfyImage>?>(); |
||||
foreach (var (nodeKey, output) in history.Outputs) |
||||
{ |
||||
shared.Return(buffer); |
||||
dict[nodeKey] = output.Images; |
||||
} |
||||
return dict; |
||||
} |
||||
|
||||
public async Task CloseAsync() |
||||
{ |
||||
await clientWebSocket |
||||
.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", cancellationToken) |
||||
.ConfigureAwait(false); |
||||
} |
||||
|
||||
public void Dispose() |
||||
protected override void Dispose(bool disposing) |
||||
{ |
||||
clientWebSocket.Dispose(); |
||||
cancellationTokenSource.Dispose(); |
||||
if (isDisposed) return; |
||||
webSocketClient.Dispose(); |
||||
isDisposed = true; |
||||
} |
||||
} |
||||
|
@ -0,0 +1,103 @@
|
||||
using System.Buffers; |
||||
using System.Net.WebSockets; |
||||
using System.Text; |
||||
using System.Text.Json; |
||||
using StabilityMatrix.Core.Helper; |
||||
using StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
namespace StabilityMatrix.Core.Inference; |
||||
|
||||
/// <summary> |
||||
/// Websocket client for Comfy inference server |
||||
/// Connects to localhost:8188 by default |
||||
/// </summary> |
||||
public class ComfyWebSocketClient : IDisposable |
||||
{ |
||||
private bool isDisposed; |
||||
private readonly ClientWebSocket clientWebSocket = new(); |
||||
private readonly CancellationTokenSource cancellationTokenSource = new(); |
||||
private readonly CancellationToken cancellationToken; |
||||
private readonly JsonSerializerOptions? jsonSerializerOptions = new() |
||||
{ |
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase, |
||||
WriteIndented = true |
||||
}; |
||||
|
||||
public ComfyWebSocketClient() |
||||
{ |
||||
cancellationToken = cancellationTokenSource.Token; |
||||
} |
||||
|
||||
public async Task ConnectAsync(Uri baseAddress, string clientId) |
||||
{ |
||||
var uri = new UriBuilder(baseAddress) |
||||
{ |
||||
Path = "/ws", |
||||
Query = $"client_id={clientId}" |
||||
}.Uri; |
||||
await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); |
||||
} |
||||
|
||||
public async Task SendAsync<T>(T message) |
||||
{ |
||||
var json = JsonSerializer.Serialize(message, jsonSerializerOptions); |
||||
var bytes = Encoding.UTF8.GetBytes(json); |
||||
var buffer = new ArraySegment<byte>(bytes); |
||||
await clientWebSocket |
||||
.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken) |
||||
.ConfigureAwait(false); |
||||
} |
||||
|
||||
public async Task<ComfyWebSocketResponseUnion?> ReceiveAsync() |
||||
{ |
||||
var shared = ArrayPool<byte>.Shared; |
||||
var buffer = shared.Rent(1024); |
||||
try |
||||
{ |
||||
var result = await clientWebSocket |
||||
.ReceiveAsync(buffer, cancellationToken) |
||||
.ConfigureAwait(false); |
||||
|
||||
if (result.MessageType is WebSocketMessageType.Binary) |
||||
{ |
||||
return new ComfyWebSocketResponseUnion |
||||
{ |
||||
MessageType = result.MessageType, |
||||
Json = null, |
||||
Bytes = buffer.AsSpan(0, result.Count).ToArray() |
||||
}; |
||||
} |
||||
else |
||||
{ |
||||
var text = Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count)); |
||||
var json = JsonSerializer.Deserialize<ComfyWebSocketResponse>(text, jsonSerializerOptions); |
||||
return new ComfyWebSocketResponseUnion |
||||
{ |
||||
MessageType = result.MessageType, |
||||
Json = json, |
||||
Bytes = null |
||||
}; |
||||
} |
||||
} |
||||
finally |
||||
{ |
||||
shared.Return(buffer); |
||||
} |
||||
} |
||||
|
||||
public async Task CloseAsync() |
||||
{ |
||||
await clientWebSocket |
||||
.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", cancellationToken) |
||||
.ConfigureAwait(false); |
||||
} |
||||
|
||||
public void Dispose() |
||||
{ |
||||
if (isDisposed) return; |
||||
clientWebSocket.Dispose(); |
||||
cancellationTokenSource.Dispose(); |
||||
isDisposed = true; |
||||
GC.SuppressFinalize(this); |
||||
} |
||||
} |
@ -1,6 +0,0 @@
|
||||
namespace StabilityMatrix.Core.Inference; |
||||
|
||||
public interface IInferenceClient |
||||
{ |
||||
|
||||
} |
@ -0,0 +1,30 @@
|
||||
namespace StabilityMatrix.Core.Inference; |
||||
|
||||
public abstract class InferenceClientBase : IDisposable |
||||
{ |
||||
/// <summary> |
||||
/// Start the connection |
||||
/// </summary> |
||||
public virtual Task ConnectAsync(CancellationToken cancellationToken = default) |
||||
{ |
||||
return Task.CompletedTask; |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Close the connection to remote resources |
||||
/// </summary> |
||||
public virtual Task CloseAsync(CancellationToken cancellationToken = default) |
||||
{ |
||||
return Task.CompletedTask; |
||||
} |
||||
|
||||
protected virtual void Dispose(bool disposing) |
||||
{ |
||||
} |
||||
|
||||
public void Dispose() |
||||
{ |
||||
Dispose(true); |
||||
GC.SuppressFinalize(this); |
||||
} |
||||
} |
@ -0,0 +1,9 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public class ComfyHistoryOutput |
||||
{ |
||||
[JsonPropertyName("images")] |
||||
public List<ComfyImage>? Images { get; set; } |
||||
} |
@ -0,0 +1,9 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public class ComfyHistoryResponse |
||||
{ |
||||
[JsonPropertyName("outputs")] |
||||
public required Dictionary<string, ComfyHistoryOutput> Outputs { get; set; } |
||||
} |
@ -0,0 +1,15 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public class ComfyImage |
||||
{ |
||||
[JsonPropertyName("filename")] |
||||
public required string FileName { get; set; } |
||||
|
||||
[JsonPropertyName("subfolder")] |
||||
public required string SubFolder { get; set; } |
||||
|
||||
[JsonPropertyName("type")] |
||||
public required string Type { get; set; } |
||||
} |
@ -0,0 +1,12 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public class ComfyNode |
||||
{ |
||||
[JsonPropertyName("class_type")] |
||||
public required string ClassType; |
||||
|
||||
[JsonPropertyName("inputs")] |
||||
public required Dictionary<string, object?> Inputs; |
||||
} |
@ -0,0 +1,12 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public class ComfyPromptRequest |
||||
{ |
||||
[JsonPropertyName("client_id")] |
||||
public required string ClientId; |
||||
|
||||
[JsonPropertyName("prompt")] |
||||
public required Dictionary<string, ComfyNode> Prompt; |
||||
} |
@ -0,0 +1,16 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
// ReSharper disable once ClassNeverInstantiated.Global |
||||
public class ComfyPromptResponse |
||||
{ |
||||
[JsonPropertyName("prompt_id")] |
||||
public required string PromptId { get; set; } |
||||
|
||||
[JsonPropertyName("number")] |
||||
public required int Number { get; set; } |
||||
|
||||
[JsonPropertyName("node_errors")] |
||||
public required Dictionary<string, object?> NodeErrors { get; set; } |
||||
} |
@ -0,0 +1,29 @@
|
||||
using System.Text.Json; |
||||
using System.Text.Json.Nodes; |
||||
using System.Text.Json.Serialization; |
||||
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public class ComfyWebSocketResponse |
||||
{ |
||||
[JsonPropertyName("type")] |
||||
public required ComfyWebSocketResponseType Type { get; set; } |
||||
|
||||
/// <summary> |
||||
/// Depending on the value of <see cref="Type"/>, |
||||
/// this property will be one of these types |
||||
/// <list type="bullet"> |
||||
/// <item>Status - <see cref="ComfyWebSocketStatusData"/></item> |
||||
/// <item>Progress - <see cref="ComfyWebSocketProgressData"/></item> |
||||
/// <item>Executing - <see cref="ComfyWebSocketExecutingData"/></item> |
||||
/// </list> |
||||
/// </summary> |
||||
[JsonPropertyName("data")] |
||||
public required JsonObject Data { get; set; } |
||||
|
||||
public T? GetDataAsType<T>() where T : class
|
||||
{ |
||||
return Data.Deserialize<T>(); |
||||
} |
||||
} |
@ -0,0 +1,29 @@
|
||||
using System.Runtime.Serialization; |
||||
using System.Text.Json.Serialization; |
||||
using StabilityMatrix.Core.Converters.Json; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
[JsonConverter(typeof(DefaultUnknownEnumConverter<ComfyWebSocketResponseType>))] |
||||
public enum ComfyWebSocketResponseType |
||||
{ |
||||
Unknown, |
||||
|
||||
[EnumMember(Value = "status")] |
||||
Status, |
||||
|
||||
[EnumMember(Value = "execution_start")] |
||||
ExecutionStart, |
||||
|
||||
[EnumMember(Value = "execution_cached")] |
||||
ExecutionCached, |
||||
|
||||
[EnumMember(Value = "executing")] |
||||
Executing, |
||||
|
||||
[EnumMember(Value = "progress")] |
||||
Progress, |
||||
|
||||
[EnumMember(Value = "executed")] |
||||
Executed, |
||||
} |
@ -0,0 +1,12 @@
|
||||
using System.Net.WebSockets; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy; |
||||
|
||||
public record ComfyWebSocketResponseUnion |
||||
{ |
||||
public WebSocketMessageType MessageType { get; set; } |
||||
public ComfyWebSocketResponse? Json { get; set; } |
||||
public byte[]? Bytes { get; set; } |
||||
}; |
||||
|
||||
|
@ -0,0 +1,9 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
|
||||
public record struct ComfyStatus |
||||
{ |
||||
[JsonPropertyName("exec_info")] |
||||
public required int ExecInfo { get; set; } |
||||
} |
@ -0,0 +1,9 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
|
||||
public record ComfyStatusExecInfo |
||||
{ |
||||
[JsonPropertyName("queue_remaining")] |
||||
public required int QueueRemaining { get; set; } |
||||
} |
@ -0,0 +1,15 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
|
||||
public class ComfyWebSocketExecutingData |
||||
{ |
||||
[JsonPropertyName("prompt_id")] |
||||
public required string PromptId { get; set; } |
||||
|
||||
/// <summary> |
||||
/// When this is null it indicates completed |
||||
/// </summary> |
||||
[JsonPropertyName("node")] |
||||
public required int? Node { get; set; } |
||||
} |
@ -0,0 +1,12 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
|
||||
public record ComfyWebSocketProgressData |
||||
{ |
||||
[JsonPropertyName("value")] |
||||
public required int Value { get; set; } |
||||
|
||||
[JsonPropertyName("max")] |
||||
public required int Max { get; set; } |
||||
} |
@ -0,0 +1,9 @@
|
||||
using System.Text.Json.Serialization; |
||||
|
||||
namespace StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
||||
|
||||
public record ComfyWebSocketStatusData |
||||
{ |
||||
[JsonPropertyName("status")] |
||||
public required ComfyStatus Status { get; set; } |
||||
} |
Loading…
Reference in new issue