You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
431 lines
14 KiB
431 lines
14 KiB
using System.Collections.Concurrent; |
|
using System.Collections.Immutable; |
|
using System.Net.WebSockets; |
|
using System.Text.Json; |
|
using NLog; |
|
using Polly.Contrib.WaitAndRetry; |
|
using Refit; |
|
using StabilityMatrix.Core.Api; |
|
using StabilityMatrix.Core.Exceptions; |
|
using StabilityMatrix.Core.Extensions; |
|
using StabilityMatrix.Core.Models.Api.Comfy; |
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; |
|
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
using Websocket.Client; |
|
using Websocket.Client.Exceptions; |
|
using Yoh.Text.Json.NamingPolicies; |
|
|
|
namespace StabilityMatrix.Core.Inference; |
|
|
|
/// <summary>
/// Client for a ComfyUI server. Talks to the HTTP API through <see cref="IComfyApi"/> and
/// listens for executing, status, progress, execution-error and preview messages over a websocket.
/// </summary>
public class ComfyClient : InferenceClientBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Length of the header on binary websocket messages: a 4-byte message type followed by
    /// a 4-byte image format. The image data starts right after it.
    /// </summary>
    private const int BinaryMessageHeaderLength = 8;

    private readonly WebsocketClient webSocketClient;

    private readonly IComfyApi comfyApi;

    private bool isDisposed;

    private readonly JsonSerializerOptions jsonSerializerOptions =
        new() { PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower, };

    /// <summary>
    /// Identifier sent as the websocket <c>clientId</c> query parameter and with every queued prompt.
    /// </summary>
    // ReSharper disable once MemberCanBePrivate.Global
    public string ClientId { get; } = Guid.NewGuid().ToString();

    /// <summary>
    /// Base HTTP address of the server; the websocket address is derived from it.
    /// </summary>
    public Uri BaseAddress { get; }

    /// <summary>
    /// Optional local path to output images.
    /// </summary>
    public DirectoryPath? OutputImagesDir { get; set; }

    /// <summary>
    /// Optional local path to input images.
    /// </summary>
    public DirectoryPath? InputImagesDir { get; set; }

    /// <summary>
    /// Dictionary of ongoing prompt execution tasks, keyed by prompt id.
    /// </summary>
    public ConcurrentDictionary<string, ComfyTask> PromptTasks { get; } = new();

    /// <summary>
    /// Task of the prompt most recently queued or reported as executing.
    /// Progress updates are forwarded to it.
    /// </summary>
    private ComfyTask? currentPromptTask;

    /// <summary>
    /// Event raised when a progress update is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketProgressData>? ProgressUpdateReceived;

    /// <summary>
    /// Event raised when a status update is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketStatusData>? StatusUpdateReceived;

    /// <summary>
    /// Event raised when an executing update is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketExecutingData>? ExecutingUpdateReceived;

    /// <summary>
    /// Event raised when a preview image is received from the server
    /// </summary>
    public event EventHandler<ComfyWebSocketImageData>? PreviewImageReceived;

    /// <summary>
    /// Creates the HTTP API client and configures (but does not start) the websocket client.
    /// </summary>
    /// <param name="apiFactory">Factory used to create the Refit <see cref="IComfyApi"/> client.</param>
    /// <param name="baseAddress">Base HTTP(S) address of the server.</param>
    public ComfyClient(IApiFactory apiFactory, Uri baseAddress)
    {
        comfyApi = apiFactory.CreateRefitClient<IComfyApi>(baseAddress);
        BaseAddress = baseAddress;

        // Use a secure websocket when the API itself is served over https
        var wsScheme = string.Equals(
            baseAddress.Scheme,
            Uri.UriSchemeHttps,
            StringComparison.OrdinalIgnoreCase
        )
            ? "wss"
            : "ws";

        var wsUri = new UriBuilder(baseAddress)
        {
            Scheme = wsScheme,
            Path = "/ws",
            Query = $"clientId={ClientId}"
        }.Uri;

        webSocketClient = new WebsocketClient(wsUri)
        {
            Name = nameof(ComfyClient),
            ReconnectTimeout = TimeSpan.FromSeconds(30)
        };

        webSocketClient.DisconnectionHappened.Subscribe(
            info => Logger.Info("Websocket Disconnected, ({Type})", info.Type)
        );
        webSocketClient.ReconnectionHappened.Subscribe(
            info => Logger.Info("Websocket Reconnected, ({Type})", info.Type)
        );

        webSocketClient.MessageReceived.Subscribe(OnMessageReceived);
    }

    /// <summary>
    /// Dispatches a raw websocket message by its frame type.
    /// </summary>
    private void OnMessageReceived(ResponseMessage message)
    {
        switch (message.MessageType)
        {
            case WebSocketMessageType.Text:
                HandleTextMessage(message.Text);
                break;
            case WebSocketMessageType.Binary:
                HandleBinaryMessage(message.Binary);
                break;
            case WebSocketMessageType.Close:
                Logger.Trace("Received ws close message: {Text}", message.Text);
                break;
            default:
                // This runs inside the MessageReceived subscription; throwing here would
                // propagate out of the observer, so log and drop the message instead.
                Logger.Warn(
                    "Unknown websocket message type {MessageType}, skipping",
                    message.MessageType
                );
                break;
        }
    }

    /// <summary>
    /// Parses a JSON text message and routes it to the handler for its message type.
    /// Unparseable or unknown messages are logged and skipped.
    /// </summary>
    private void HandleTextMessage(string text)
    {
        ComfyWebSocketResponse? json;
        try
        {
            json = JsonSerializer.Deserialize<ComfyWebSocketResponse>(text, jsonSerializerOptions);
        }
        catch (JsonException e)
        {
            Logger.Warn(e, "Failed to parse json {Text}, skipping", text);
            return;
        }

        if (json is null)
        {
            Logger.Warn("Could not parse json {Text}, skipping", text);
            return;
        }

        Logger.Trace("Received json message: (Type = {Type}, Data = {Data})", json.Type, json.Data);

        if (json.Type == ComfyWebSocketResponseType.Executing)
        {
            HandleExecutingMessage(json);
        }
        else if (json.Type == ComfyWebSocketResponseType.Status)
        {
            HandleStatusMessage(json);
        }
        else if (json.Type == ComfyWebSocketResponseType.Progress)
        {
            HandleProgressMessage(json);
        }
        else if (json.Type == ComfyWebSocketResponseType.ExecutionError)
        {
            HandleExecutionErrorMessage(json);
        }
        else
        {
            Logger.Warn("Unknown message type {Type} ({Data}), skipping", json.Type, json.Data);
        }
    }

    /// <summary>
    /// Handles an "executing" message. A null node completes the prompt's task; otherwise
    /// the node is recorded as the task's running node.
    /// </summary>
    private void HandleExecutingMessage(ComfyWebSocketResponse json)
    {
        var executingData = json.GetDataAsType<ComfyWebSocketExecutingData>(jsonSerializerOptions);
        if (executingData?.PromptId is null)
        {
            Logger.Warn("Could not parse executing data {Data}, skipping", json.Data);
            return;
        }

        if (executingData.Node is null)
        {
            // A null node means the prompt has finished executing:
            // remove the task from the dictionary and complete it
            if (PromptTasks.TryRemove(executingData.PromptId, out var finishedTask))
            {
                finishedTask.RunningNode = null;
                finishedTask.SetResult();
                ClearCurrentPromptTask(finishedTask);
            }
            else
            {
                Logger.Warn(
                    "Could not find task for prompt {PromptId}, skipping",
                    executingData.PromptId
                );
            }
        }
        else if (PromptTasks.TryGetValue(executingData.PromptId, out var runningTask))
        {
            // Record the active node. Track this task as the current one so that progress
            // updates go to the prompt the server is actually executing, even when a later
            // prompt has already been queued.
            runningTask.RunningNode = executingData.Node;
            currentPromptTask = runningTask;
        }

        ExecutingUpdateReceived?.Invoke(this, executingData);
    }

    /// <summary>
    /// Handles a "status" message by raising <see cref="StatusUpdateReceived"/>.
    /// </summary>
    private void HandleStatusMessage(ComfyWebSocketResponse json)
    {
        var statusData = json.GetDataAsType<ComfyWebSocketStatusData>(jsonSerializerOptions);
        if (statusData is null)
        {
            Logger.Warn("Could not parse status data {Data}, skipping", json.Data);
            return;
        }

        StatusUpdateReceived?.Invoke(this, statusData);
    }

    /// <summary>
    /// Handles a "progress" message: forwards it to the current prompt task and raises
    /// <see cref="ProgressUpdateReceived"/>.
    /// </summary>
    private void HandleProgressMessage(ComfyWebSocketResponse json)
    {
        var progressData = json.GetDataAsType<ComfyWebSocketProgressData>(jsonSerializerOptions);
        if (progressData is null)
        {
            Logger.Warn("Could not parse progress data {Data}, skipping", json.Data);
            return;
        }

        currentPromptTask?.OnProgressUpdate(progressData);

        ProgressUpdateReceived?.Invoke(this, progressData);
    }

    /// <summary>
    /// Handles an "execution_error" message by failing the prompt's task with a
    /// <see cref="ComfyNodeException"/>.
    /// </summary>
    private void HandleExecutionErrorMessage(ComfyWebSocketResponse json)
    {
        if (
            json.GetDataAsType<ComfyWebSocketExecutionErrorData>(jsonSerializerOptions)
            is not { } errorData
        )
        {
            Logger.Warn("Could not parse ExecutionError data {Data}, skipping", json.Data);
            return;
        }

        if (PromptTasks.TryRemove(errorData.PromptId, out var failedTask))
        {
            failedTask.RunningNode = null;
            failedTask.SetException(
                new ComfyNodeException { ErrorData = errorData, JsonData = json.Data.ToString() }
            );
            ClearCurrentPromptTask(failedTask);
        }
        else
        {
            Logger.Warn("Could not find task for prompt {PromptId}, skipping", errorData.PromptId);
        }
    }

    /// <summary>
    /// Clears <see cref="currentPromptTask"/> only if it still refers to <paramref name="task"/>,
    /// so that finishing one prompt does not detach a different, newer current task.
    /// </summary>
    private void ClearCurrentPromptTask(ComfyTask task)
    {
        if (ReferenceEquals(currentPromptTask, task))
        {
            currentPromptTask = null;
        }
    }

    /// <summary>
    /// Parses binary data (previews) into image bytes and raises <see cref="PreviewImageReceived"/>.
    /// https://github.com/comfyanonymous/ComfyUI/blob/master/server.py#L518
    /// </summary>
    /// <remarks>
    /// Layout: 4-byte message type, 4-byte image format (both packed by the server with
    /// Python's <c>struct.pack(">I", ...)</c>), then the image data. Only the image data is used.
    /// </remarks>
    private void HandleBinaryMessage(byte[] data)
    {
        // The full header must be present before it can be sliced off; shorter payloads
        // would otherwise make the range slice below throw.
        if (data is not { Length: > BinaryMessageHeaderLength })
        {
            Logger.Warn("The input data is null or not long enough.");
            return;
        }

        PreviewImageReceived?.Invoke(
            this,
            new ComfyWebSocketImageData { ImageBytes = data[BinaryMessageHeaderLength..], }
        );
    }

    /// <summary>
    /// Starts the websocket connection. Makes one attempt per backoff delay (decorrelated
    /// jitter, 5 delays) and waits that delay between failed attempts.
    /// </summary>
    /// <param name="cancellationToken">Cancels the retry loop, including any pending delay.</param>
    /// <exception cref="WebsocketException">The final connection attempt failed.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    public override async Task ConnectAsync(CancellationToken cancellationToken = default)
    {
        var delays = Backoff
            .DecorrelatedJitterBackoffV2(TimeSpan.FromMilliseconds(500), retryCount: 5)
            .ToImmutableArray();

        for (var attempt = 0; attempt < delays.Length; attempt++)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var retryDelay = delays[attempt];
            try
            {
                await webSocketClient.StartOrFail().ConfigureAwait(false);
                return;
            }
            // On the final attempt the filter is false, so the exception propagates to the caller
            catch (WebsocketException e) when (attempt < delays.Length - 1)
            {
                Logger.Info(
                    "Failed to connect to websocket, retrying in {RetryDelay} ({Message})",
                    retryDelay,
                    e.Message
                );
            }

            // Actually wait out the backoff before retrying
            await Task.Delay(retryDelay, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Stops the websocket with a normal closure status.
    /// </summary>
    public override async Task CloseAsync(CancellationToken cancellationToken = default)
    {
        await webSocketClient
            .Stop(WebSocketCloseStatus.NormalClosure, string.Empty)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Queues a prompt on the server and returns a task that tracks its execution.
    /// </summary>
    /// <param name="nodes">Prompt graph, keyed by node id.</param>
    /// <param name="cancellationToken">Cancels the HTTP request.</param>
    /// <returns>The tracking task, also added to <see cref="PromptTasks"/> and made current.</returns>
    public async Task<ComfyTask> QueuePromptAsync(
        Dictionary<string, ComfyNode> nodes,
        CancellationToken cancellationToken = default
    )
    {
        var request = new ComfyPromptRequest { ClientId = ClientId, Prompt = nodes };
        var result = await comfyApi.PostPrompt(request, cancellationToken).ConfigureAwait(false);

        // Add task to dictionary and set it as the current task
        var task = new ComfyTask(result.PromptId);
        PromptTasks.TryAdd(result.PromptId, task);
        currentPromptTask = task;

        return task;
    }

    /// <summary>
    /// Sends an interrupt request to the server. Then cancels, removes and disposes the
    /// current prompt task, if there is one.
    /// </summary>
    public async Task InterruptPromptAsync(CancellationToken cancellationToken = default)
    {
        await comfyApi.PostInterrupt(cancellationToken).ConfigureAwait(false);

        // Set the current task to null, and remove it from the dictionary
        if (currentPromptTask is { } task)
        {
            PromptTasks.TryRemove(task.Id, out _);
            task.TrySetCanceled(cancellationToken);
            task.Dispose();
            currentPromptTask = null;
        }
    }

    /// <summary>
    /// Uploads an image stream to the server's image upload endpoint.
    /// </summary>
    /// <param name="image">Image content to upload.</param>
    /// <param name="fileName">File name sent with the multipart part.</param>
    /// <param name="cancellationToken">Cancels the HTTP request.</param>
    public Task<ComfyUploadImageResponse> UploadImageAsync(
        Stream image,
        string fileName,
        CancellationToken cancellationToken = default
    )
    {
        var streamPart = new StreamPart(image, fileName);
        // NOTE(review): positional literals are passed straight to IComfyApi.PostUploadImage;
        // confirm their meaning against that interface.
        return comfyApi.PostUploadImage(streamPart, "true", "input", "Inference", cancellationToken);
    }

    /// <summary>
    /// Fetches the history entry for <paramref name="promptId"/> and maps each output node id
    /// to its list of images (which may be null).
    /// </summary>
    public async Task<Dictionary<string, List<ComfyImage>?>> GetImagesForExecutedPromptAsync(
        string promptId,
        CancellationToken cancellationToken = default
    )
    {
        // Get history for images
        var history = await comfyApi.GetHistory(promptId, cancellationToken).ConfigureAwait(false);

        // Get the current prompt history
        var current = history[promptId];

        var dict = new Dictionary<string, List<ComfyImage>?>();
        foreach (var (nodeKey, output) in current.Outputs)
        {
            dict[nodeKey] = output.Images;
        }

        return dict;
    }

    /// <summary>
    /// Downloads the content of an image produced by the server.
    /// </summary>
    public async Task<Stream> GetImageStreamAsync(
        ComfyImage comfyImage,
        CancellationToken cancellationToken = default
    )
    {
        var response = await comfyApi
            .GetImage(comfyImage.FileName, comfyImage.SubFolder, comfyImage.Type, cancellationToken)
            .ConfigureAwait(false);
        return response;
    }

    /// <summary>
    /// Get a list of strings representing available model names
    /// </summary>
    public Task<List<string>?> GetModelNamesAsync(CancellationToken cancellationToken = default)
    {
        return GetNodeOptionNamesAsync("CheckpointLoaderSimple", "ckpt_name", cancellationToken);
    }

    /// <summary>
    /// Get a list of strings representing available sampler names
    /// </summary>
    public Task<List<string>?> GetSamplerNamesAsync(CancellationToken cancellationToken = default)
    {
        return GetNodeOptionNamesAsync("KSampler", "sampler_name", cancellationToken);
    }

    /// <summary>
    /// Get a list of strings representing available options of a given node
    /// </summary>
    /// <param name="nodeName">Node class name to query object info for.</param>
    /// <param name="optionName">Name of the required input whose options are returned.</param>
    /// <param name="cancellationToken">Cancels the HTTP request.</param>
    public async Task<List<string>?> GetNodeOptionNamesAsync(
        string nodeName,
        string optionName,
        CancellationToken cancellationToken = default
    )
    {
        var response = await comfyApi
            .GetObjectInfo(nodeName, cancellationToken)
            .ConfigureAwait(false);

        var info = response[nodeName];
        return info.Input.GetRequiredValueAsNestedList(optionName);
    }

    /// <summary>
    /// Disposes the websocket client when called from <c>Dispose()</c>; repeated calls are no-ops.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        if (isDisposed)
        {
            return;
        }

        // Only touch managed objects on the explicit-dispose path
        if (disposing)
        {
            webSocketClient.Dispose();
        }

        isDisposed = true;
    }
}
|
|
|