|
|
|
using System.Collections.Concurrent;
|
|
|
|
using System.Collections.Immutable;
|
|
|
|
using System.Net.WebSockets;
|
|
|
|
using System.Text.Json;
|
|
|
|
using NLog;
|
|
|
|
using Polly.Contrib.WaitAndRetry;
|
|
|
|
using StabilityMatrix.Core.Api;
|
|
|
|
using StabilityMatrix.Core.Extensions;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
|
|
|
|
using StabilityMatrix.Core.Models.FileInterfaces;
|
|
|
|
using Websocket.Client;
|
|
|
|
using Websocket.Client.Exceptions;
|
|
|
|
|
|
|
|
namespace StabilityMatrix.Core.Inference;
|
|
|
|
|
|
|
|
/// <summary>
/// Inference client for a ComfyUI server. Prompts are queued and queried through
/// the <see cref="IComfyApi"/> HTTP client, while execution, status, progress and
/// preview updates are received over a websocket connection and surfaced as events.
/// </summary>
public class ComfyClient : InferenceClientBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Number of leading bytes in a binary websocket message that precede the image data:
    /// a 4-byte message type followed by a 4-byte image format.
    /// </summary>
    private const int BinaryHeaderLength = 8;

    private readonly WebsocketClient webSocketClient;
    private readonly IComfyApi comfyApi;
    private bool isDisposed;

    /// <summary>
    /// Unique id generated for this client instance. It is sent as the <c>clientId</c>
    /// query parameter of the websocket uri and with every queued prompt request.
    /// </summary>
    // ReSharper disable once MemberCanBePrivate.Global
    public string ClientId { get; } = Guid.NewGuid().ToString();

    /// <summary>
    /// Base address of the server, as passed to the constructor.
    /// </summary>
    public Uri BaseAddress { get; }

    /// <summary>
    /// Optional local path to output images.
    /// </summary>
    public DirectoryPath? OutputImagesDir { get; set; }

    /// <summary>
    /// Optional local path to input images.
    /// </summary>
    public DirectoryPath? InputImagesDir { get; set; }

    /// <summary>
    /// Dictionary of ongoing prompt execution tasks, keyed by prompt id.
    /// </summary>
    public ConcurrentDictionary<string, ComfyTask> PromptTasks { get; } = new();

    /// <summary>
    /// Task that receives progress updates; set to the most recently queued prompt.
    /// </summary>
    private ComfyTask? currentPromptTask;

    /// <summary>
    /// Event raised when a progress update is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketProgressData>? ProgressUpdateReceived;

    /// <summary>
    /// Event raised when a status update is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketStatusData>? StatusUpdateReceived;

    /// <summary>
    /// Event raised when an executing update is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketExecutingData>? ExecutingUpdateReceived;

    /// <summary>
    /// Event raised when a preview image is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketImageData>? PreviewImageReceived;

    /// <summary>
    /// Creates the HTTP api client and the (not yet started) websocket client
    /// for the server at <paramref name="baseAddress"/>.
    /// </summary>
    /// <param name="apiFactory">Factory used to create the <see cref="IComfyApi"/> client.</param>
    /// <param name="baseAddress">Base address of the server.</param>
    public ComfyClient(IApiFactory apiFactory, Uri baseAddress)
    {
        comfyApi = apiFactory.CreateRefitClient<IComfyApi>(baseAddress);
        BaseAddress = baseAddress;

        // Setup websocket client.
        // An https base address needs the TLS websocket scheme, otherwise the
        // connection would be attempted in plaintext against a TLS endpoint.
        var isSecure = string.Equals(
            baseAddress.Scheme,
            Uri.UriSchemeHttps,
            StringComparison.OrdinalIgnoreCase
        );

        var wsUri = new UriBuilder(baseAddress)
        {
            Scheme = isSecure ? "wss" : "ws",
            Path = "/ws",
            Query = $"clientId={ClientId}"
        }.Uri;

        webSocketClient = new WebsocketClient(wsUri)
        {
            Name = "ComfyClient",
            ReconnectTimeout = TimeSpan.FromSeconds(30)
        };

        webSocketClient.DisconnectionHappened.Subscribe(
            info => Logger.Info("Websocket Disconnected, ({Type})", info.Type)
        );
        webSocketClient.ReconnectionHappened.Subscribe(
            info => Logger.Info("Websocket Reconnected, ({Type})", info.Type)
        );

        webSocketClient.MessageReceived.Subscribe(OnMessageReceived);
    }

    /// <summary>
    /// Dispatches an incoming websocket message to the text or binary handler.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The message type is not text, binary or close.</exception>
    private void OnMessageReceived(ResponseMessage message)
    {
        switch (message.MessageType)
        {
            case WebSocketMessageType.Text:
                HandleTextMessage(message.Text);
                break;
            case WebSocketMessageType.Binary:
                HandleBinaryMessage(message.Binary);
                break;
            case WebSocketMessageType.Close:
                Logger.Trace("Received ws close message: {Text}", message.Text);
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(message));
        }
    }

    /// <summary>
    /// Parses a json text message and routes it by its type.
    /// Messages that cannot be parsed or have an unknown type are logged and skipped.
    /// </summary>
    private void HandleTextMessage(string text)
    {
        ComfyWebSocketResponse? json;
        try
        {
            json = JsonSerializer.Deserialize<ComfyWebSocketResponse>(text);
        }
        catch (JsonException e)
        {
            Logger.Warn($"Failed to parse json {text} ({e}), skipping");
            return;
        }

        if (json is null)
        {
            Logger.Warn($"Could not parse json {text}, skipping");
            return;
        }

        Logger.Trace("Received json message: (Type = {Type}, Data = {Data})", json.Type, json.Data);

        if (json.Type == ComfyWebSocketResponseType.Executing)
        {
            HandleExecutingMessage(json);
        }
        else if (json.Type == ComfyWebSocketResponseType.Status)
        {
            HandleStatusMessage(json);
        }
        else if (json.Type == ComfyWebSocketResponseType.Progress)
        {
            HandleProgressMessage(json);
        }
        else
        {
            Logger.Warn($"Unknown message type {json.Type} ({json.Data}), skipping");
        }
    }

    /// <summary>
    /// Handles an "executing" message: updates the running node of the matching task,
    /// or completes and removes the task when the prompt has finished, then raises
    /// <see cref="ExecutingUpdateReceived"/>.
    /// </summary>
    private void HandleExecutingMessage(ComfyWebSocketResponse json)
    {
        var executingData = json.GetDataAsType<ComfyWebSocketExecutingData>();
        if (executingData?.PromptId is null)
        {
            Logger.Warn($"Could not parse executing data {json.Data}, skipping");
            return;
        }

        if (executingData.Node is null)
        {
            // When Node property is null, it means the prompt has finished executing
            // remove the task from the dictionary and set the result
            if (PromptTasks.TryRemove(executingData.PromptId, out var finishedTask))
            {
                finishedTask.RunningNode = null;
                finishedTask.SetResult();

                // Only clear the current task if it is the one that finished. A newer
                // prompt may have been queued in the meantime and must keep receiving
                // progress updates.
                if (ReferenceEquals(currentPromptTask, finishedTask))
                {
                    currentPromptTask = null;
                }
            }
            else
            {
                Logger.Warn(
                    $"Could not find task for prompt {executingData.PromptId}, skipping"
                );
            }
        }
        else if (PromptTasks.TryGetValue(executingData.PromptId, out var runningTask))
        {
            // Otherwise set the task's active node to the one received
            runningTask.RunningNode = executingData.Node;
        }

        ExecutingUpdateReceived?.Invoke(this, executingData);
    }

    /// <summary>
    /// Handles a "status" message by raising <see cref="StatusUpdateReceived"/>.
    /// </summary>
    private void HandleStatusMessage(ComfyWebSocketResponse json)
    {
        var statusData = json.GetDataAsType<ComfyWebSocketStatusData>();
        if (statusData is null)
        {
            Logger.Warn($"Could not parse status data {json.Data}, skipping");
            return;
        }

        StatusUpdateReceived?.Invoke(this, statusData);
    }

    /// <summary>
    /// Handles a "progress" message: forwards it to the current prompt task (if any)
    /// and raises <see cref="ProgressUpdateReceived"/>.
    /// </summary>
    private void HandleProgressMessage(ComfyWebSocketResponse json)
    {
        var progressData = json.GetDataAsType<ComfyWebSocketProgressData>();
        if (progressData is null)
        {
            Logger.Warn($"Could not parse progress data {json.Data}, skipping");
            return;
        }

        // Set for the current prompt task
        currentPromptTask?.OnProgressUpdate(progressData);

        ProgressUpdateReceived?.Invoke(this, progressData);
    }

    /// <summary>
    /// Parses binary data (previews) into image bytes and raises <see cref="PreviewImageReceived"/>.
    /// https://github.com/comfyanonymous/ComfyUI/blob/master/server.py#L518
    /// </summary>
    /// <remarks>
    /// Layout of the message: the first 4 bytes are the int32 message type, the following
    /// 4 bytes are the int32 image format, and the rest is the image data. The header is
    /// currently skipped without being interpreted.
    /// </remarks>
    private void HandleBinaryMessage(byte[] data)
    {
        // The payload is sliced after the 8-byte header, so anything that is not longer
        // than the header carries no image data (and slicing a shorter array would throw).
        if (data is not { Length: > BinaryHeaderLength })
        {
            Logger.Warn("The input data is null or not long enough.");
            return;
        }

        PreviewImageReceived?.Invoke(
            this,
            new ComfyWebSocketImageData { ImageBytes = data[BinaryHeaderLength..] }
        );
    }

    /// <summary>
    /// Starts the websocket connection, retrying failed attempts with a jittered backoff.
    /// </summary>
    /// <param name="cancellationToken">Token checked before each attempt and while waiting between attempts.</param>
    /// <exception cref="WebsocketException">The last connection attempt failed.</exception>
    /// <exception cref="OperationCanceledException">The token was cancelled.</exception>
    public override async Task ConnectAsync(CancellationToken cancellationToken = default)
    {
        // One connection attempt per delay entry; the delay is waited after a failed attempt.
        var delays = Backoff
            .DecorrelatedJitterBackoffV2(TimeSpan.FromMilliseconds(500), retryCount: 5)
            .ToImmutableArray();

        foreach (var (i, retryDelay) in delays.Enumerate())
        {
            cancellationToken.ThrowIfCancellationRequested();

            try
            {
                await webSocketClient.StartOrFail().ConfigureAwait(false);
                return;
            }
            // The filter lets the failure of the final attempt propagate to the caller.
            catch (WebsocketException e) when (i < delays.Length - 1)
            {
                Logger.Info(
                    "Failed to connect to websocket, retrying in {RetryDelay} ({Message})",
                    retryDelay,
                    e.Message
                );

                // Actually wait out the backoff before the next attempt.
                await Task.Delay(retryDelay, cancellationToken).ConfigureAwait(false);
            }
        }
    }

    /// <summary>
    /// Stops the websocket connection with a normal closure status.
    /// </summary>
    /// <param name="cancellationToken">Not used by this implementation.</param>
    public override async Task CloseAsync(CancellationToken cancellationToken = default)
    {
        await webSocketClient
            .Stop(WebSocketCloseStatus.NormalClosure, string.Empty)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Posts a prompt to the server and registers a task tracking its execution.
    /// </summary>
    /// <param name="nodes">Nodes of the prompt, keyed by node id.</param>
    /// <param name="cancellationToken">Token passed to the HTTP request.</param>
    /// <returns>
    /// The task for the queued prompt; it is also added to <see cref="PromptTasks"/>
    /// and becomes the current task for progress updates.
    /// </returns>
    public async Task<ComfyTask> QueuePromptAsync(
        Dictionary<string, ComfyNode> nodes,
        CancellationToken cancellationToken = default
    )
    {
        var request = new ComfyPromptRequest { ClientId = ClientId, Prompt = nodes };
        var result = await comfyApi.PostPrompt(request, cancellationToken).ConfigureAwait(false);

        // Add task to dictionary and set it as the current task
        var task = new ComfyTask(result.PromptId);
        PromptTasks[result.PromptId] = task;
        currentPromptTask = task;

        return task;
    }

    /// <summary>
    /// Posts an interrupt request to the server, then cancels, disposes and
    /// unregisters the current prompt task if there is one.
    /// </summary>
    /// <param name="cancellationToken">Token passed to the HTTP request and to the task cancellation.</param>
    public async Task InterruptPromptAsync(CancellationToken cancellationToken = default)
    {
        await comfyApi.PostInterrupt(cancellationToken).ConfigureAwait(false);

        // Set the current task to null, and remove it from the dictionary
        if (currentPromptTask is { } task)
        {
            PromptTasks.TryRemove(task.Id, out _);
            task.TrySetCanceled(cancellationToken);
            task.Dispose();
            currentPromptTask = null;
        }
    }

    /// <summary>
    /// Fetches the history of an executed prompt and returns its output images per node.
    /// </summary>
    /// <param name="promptId">Id of the prompt to look up in the history response.</param>
    /// <param name="cancellationToken">Token passed to the HTTP request.</param>
    /// <returns>Images of each output node, keyed by node key; a node's list may be null.</returns>
    public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync(
        string promptId,
        CancellationToken cancellationToken = default
    )
    {
        // Get history for images
        var history = await comfyApi.GetHistory(promptId, cancellationToken).ConfigureAwait(false);

        // Get the current prompt history
        var current = history[promptId];

        var dict = new Dictionary<string, List<ComfyImage>?>();
        foreach (var (nodeKey, output) in current.Outputs)
        {
            dict[nodeKey] = output.Images;
        }
        return dict;
    }

    /// <summary>
    /// Requests the given image from the server.
    /// </summary>
    /// <param name="comfyImage">Image whose file name, sub folder and type identify the request.</param>
    /// <param name="cancellationToken">Token passed to the HTTP request.</param>
    /// <returns>The stream returned by the api client.</returns>
    public async Task<Stream> GetImageStreamAsync(
        ComfyImage comfyImage,
        CancellationToken cancellationToken = default
    )
    {
        return await comfyApi
            .GetImage(comfyImage.FileName, comfyImage.SubFolder, comfyImage.Type, cancellationToken)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Get a list of strings representing available model names
    /// </summary>
    public Task<List<string>?> GetModelNamesAsync(CancellationToken cancellationToken = default)
    {
        return GetNodeOptionNamesAsync("CheckpointLoaderSimple", "ckpt_name", cancellationToken);
    }

    /// <summary>
    /// Get a list of strings representing available sampler names
    /// </summary>
    public Task<List<string>?> GetSamplerNamesAsync(CancellationToken cancellationToken = default)
    {
        return GetNodeOptionNamesAsync("KSampler", "sampler_name", cancellationToken);
    }

    /// <summary>
    /// Get a list of strings representing available options of a given node
    /// </summary>
    /// <param name="nodeName">Name of the node to request object info for.</param>
    /// <param name="optionName">Name of the required input option whose values are returned.</param>
    /// <param name="cancellationToken">Token passed to the HTTP request.</param>
    public async Task<List<string>?> GetNodeOptionNamesAsync(
        string nodeName,
        string optionName,
        CancellationToken cancellationToken = default
    )
    {
        var response = await comfyApi
            .GetObjectInfo(nodeName, cancellationToken)
            .ConfigureAwait(false);

        var info = response[nodeName];
        return info.Input.GetRequiredValueAsNestedList(optionName);
    }

    /// <summary>
    /// Disposes the websocket client. Calls after the first one are no-ops.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        if (isDisposed)
        {
            return;
        }

        webSocketClient.Dispose();
        isDisposed = true;
    }
}
|