using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DynamicData;
using DynamicData.Binding;
using Microsoft.Extensions.Logging;
using SkiaSharp;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.Services;
/// <summary>
/// Manager for the current inference client.
/// Exposes observable shared properties (models, VAE models, samplers, upscalers, schedulers)
/// that are loaded from the connected ComfyUI client, or reset to local defaults when disconnected.
/// </summary>
public partial class InferenceClientManager : ObservableObject, IInferenceClientManager
{
    // Fallback address used when no host/port is supplied by the package launch args.
    private const string DefaultHost = "127.0.0.1";
    private const int DefaultPort = 8188;

    private readonly ILogger logger;
    private readonly IApiFactory apiFactory;
    private readonly IModelIndexService modelIndexService;
    private readonly ISettingsManager settingsManager;

    // CanUserDisconnect is derived from Client as well, so it must be notified too.
    [ObservableProperty]
    [NotifyPropertyChangedFor(
        nameof(IsConnected),
        nameof(CanUserConnect),
        nameof(CanUserDisconnect)
    )]
    private ComfyClient? client;

    /// <summary>
    /// True when a client is set; lets the compiler treat <see cref="Client"/> as non-null.
    /// </summary>
    [MemberNotNullWhen(true, nameof(Client))]
    public bool IsConnected => Client is not null;

    // CanUserDisconnect is derived from IsConnecting as well, so it must be notified too.
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(CanUserConnect), nameof(CanUserDisconnect))]
    private bool isConnecting;

    /// <inheritdoc />
    public bool CanUserConnect => !IsConnected && !IsConnecting;

    /// <inheritdoc />
    public bool CanUserDisconnect => IsConnected && !IsConnecting;

    private readonly SourceCache modelsSource = new(p => p.GetId());

    public IObservableCollection Models { get; } =
        new ObservableCollectionExtended();

    private readonly SourceCache vaeModelsSource = new(p => p.GetId());

    // Always-present VAE entries, merged with vaeModelsSource into VaeModels.
    private readonly SourceCache vaeModelsDefaults = new(p => p.GetId());

    public IObservableCollection VaeModels { get; } =
        new ObservableCollectionExtended();

    private readonly SourceCache samplersSource = new(p => p.Name);

    public IObservableCollection Samplers { get; } =
        new ObservableCollectionExtended();

    private readonly SourceCache modelUpscalersSource = new(p => p.Name);
    private readonly SourceCache latentUpscalersSource = new(p => p.Name);
    private readonly SourceCache downloadableUpscalersSource =
        new(p => p.Name);

    public IObservableCollection Upscalers { get; } =
        new ObservableCollectionExtended();

    private readonly SourceCache schedulersSource = new(p => p.Name);

    public IObservableCollection Schedulers { get; } =
        new ObservableCollectionExtended();

    public InferenceClientManager(
        ILogger logger,
        IApiFactory apiFactory,
        IModelIndexService modelIndexService,
        ISettingsManager settingsManager
    )
    {
        this.logger = logger;
        this.apiFactory = apiFactory;
        this.modelIndexService = modelIndexService;
        this.settingsManager = settingsManager;

        modelsSource.Connect().DeferUntilLoaded().Bind(Models).Subscribe();

        vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);

        vaeModelsDefaults
            .Connect()
            .Or(vaeModelsSource.Connect())
            .DeferUntilLoaded()
            .Bind(VaeModels)
            .Subscribe();

        samplersSource.Connect().DeferUntilLoaded().Bind(Samplers).Subscribe();

        // Upscalers is the union of latent, model (ESRGAN etc.) and downloadable upscalers
        latentUpscalersSource
            .Connect()
            .Or(modelUpscalersSource.Connect())
            .Or(downloadableUpscalersSource.Connect())
            .DeferUntilLoaded()
            .Bind(Upscalers)
            .Subscribe();

        schedulersSource.Connect().DeferUntilLoaded().Bind(Schedulers).Subscribe();

        settingsManager.RegisterOnLibraryDirSet(_ =>
        {
            Dispatcher.UIThread.Post(ResetSharedProperties, DispatcherPriority.Background);
        });

        // NOTE(review): unlike the library-dir callback above, this handler does not marshal
        // to the UI thread; confirm ModelIndexChanged is always raised on the UI thread.
        EventManager.Instance.ModelIndexChanged += (_, _) =>
        {
            logger.LogDebug("Model index changed, reloading shared properties for Inference");

            if (!settingsManager.IsLibraryDirSet)
                return;

            if (IsConnected)
            {
                LoadSharedPropertiesAsync()
                    .SafeFireAndForget(
                        onException: ex => logger.LogError(ex, "Error loading shared properties")
                    );
            }
            else
            {
                ResetSharedProperties();
            }
        };
    }

    /// <summary>
    /// Replaces the shared properties with values reported by the connected client.
    /// Sources are only updated when the client returns a non-null result.
    /// </summary>
    /// <exception cref="InvalidOperationException">No client is connected.</exception>
    private async Task LoadSharedPropertiesAsync()
    {
        // Snapshot the client: CloseAsync may set Client to null between the awaits below.
        if (Client is not { } currentClient)
            throw new InvalidOperationException("Client is not connected");

        if (await currentClient.GetModelNamesAsync() is { } modelNames)
        {
            modelsSource.EditDiff(
                modelNames.Select(HybridModelFile.FromRemote),
                HybridModelFile.Comparer
            );
        }

        // Fetch sampler names from KSampler node
        if (await currentClient.GetSamplerNamesAsync() is { } samplerNames)
        {
            samplersSource.EditDiff(
                samplerNames.Select(name => new ComfySampler(name)),
                ComfySampler.Comparer
            );
        }

        // Upscalers is latent and esrgan combined

        // Add latent upscale methods from LatentUpscale node
        if (
            await currentClient.GetNodeOptionNamesAsync("LatentUpscale", "upscale_method") is
            { } latentUpscalerNames
        )
        {
            latentUpscalersSource.EditDiff(
                latentUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)),
                ComfyUpscaler.Comparer
            );

            logger.LogTrace("Loaded latent upscale methods: {@Upscalers}", latentUpscalerNames);
        }

        // Add Model upscale methods
        if (
            await currentClient.GetNodeOptionNamesAsync("UpscaleModelLoader", "model_name") is
            { } modelUpscalerNames
        )
        {
            modelUpscalersSource.EditDiff(
                modelUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)),
                ComfyUpscaler.Comparer
            );
            logger.LogTrace("Loaded model upscale methods: {@Upscalers}", modelUpscalerNames);
        }

        // Add scheduler names from the KSampler node's "scheduler" option
        if (
            await currentClient.GetNodeOptionNamesAsync("KSampler", "scheduler") is
            { } schedulerNames
        )
        {
            schedulersSource.EditDiff(
                schedulerNames.Select(s => new ComfyScheduler(s)),
                ComfyScheduler.Comparer
            );
            logger.LogTrace("Loaded scheduler methods: {@Schedulers}", schedulerNames);
        }
    }

    /// <summary>
    /// Clears shared properties and sets them to local defaults
    /// (local model index contents plus built-in sampler/upscaler/scheduler defaults).
    /// </summary>
    private void ResetSharedProperties()
    {
        // Load local models
        modelsSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(SharedFolderType.StableDiffusion)
                .Select(HybridModelFile.FromLocal),
            HybridModelFile.Comparer
        );

        // Load local VAE models
        vaeModelsSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(SharedFolderType.VAE)
                .Select(HybridModelFile.FromLocal),
            HybridModelFile.Comparer
        );

        samplersSource.EditDiff(ComfySampler.Defaults, ComfySampler.Comparer);

        latentUpscalersSource.EditDiff(ComfyUpscaler.Defaults, ComfyUpscaler.Comparer);

        schedulersSource.EditDiff(ComfyScheduler.Defaults, ComfyScheduler.Comparer);

        // Load Upscalers
        modelUpscalersSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(
                    SharedFolderType.ESRGAN | SharedFolderType.RealESRGAN | SharedFolderType.SwinIR
                )
                .Select(m => new ComfyUpscaler(m.FileName, ComfyUpscalerType.ESRGAN)),
            ComfyUpscaler.Comparer
        );

        // Remote upscalers: offer downloadable models that are not already present locally.
        // Must run after modelUpscalersSource is updated above, since it looks names up there.
        var remoteUpscalers = ComfyUpscaler.DefaultDownloadableModels.Where(
            u => !modelUpscalersSource.Lookup(u.Name).HasValue
        );

        downloadableUpscalersSource.EditDiff(remoteUpscalers, ComfyUpscaler.Comparer);
    }

    /// <inheritdoc />
    /// <remarks>
    /// The image is decoded and re-encoded as PNG into the client's "Inference" input
    /// folder, which strips metadata that could otherwise cause errors. Does nothing when
    /// not connected.
    /// </remarks>
    public async Task CopyImageToInputAsync(
        FilePath imageFile,
        CancellationToken cancellationToken = default
    )
    {
        if (!IsConnected)
            return;

        if (Client.InputImagesDir is not { } inputImagesDir)
        {
            throw new InvalidOperationException("InputImagesDir is null");
        }

        var inferenceInputs = inputImagesDir.JoinDir("Inference");
        inferenceInputs.Create();

        var destination = inferenceInputs.JoinFile(imageFile.Name);

        // Read to SKImage then write to file, to prevent errors from metadata
        await Task.Run(
            () =>
            {
                using var imageStream = imageFile.Info.OpenRead();
                using var image =
                    SKImage.FromEncodedData(imageStream)
                    ?? throw new InvalidOperationException(
                        $"Could not decode image file {imageFile.Name}"
                    );
                using var encoded = image.Encode(SKEncodedImageFormat.Png, 100);

                // Create() truncates an existing file; OpenWrite() would leave stale trailing bytes.
                using var destinationStream = destination.Info.Create();
                encoded.SaveTo(destinationStream);
            },
            cancellationToken
        );
    }

    /// <inheritdoc />
    public async Task WriteImageToInputAsync(
        ImageSource imageSource,
        CancellationToken cancellationToken = default
    )
    {
        if (!IsConnected)
            return;

        if (Client.InputImagesDir is not { } inputImagesDir)
        {
            throw new InvalidOperationException("InputImagesDir is null");
        }

        var inferenceInputs = inputImagesDir.JoinDir("Inference");
        inferenceInputs.Create();

        // TODO(review): imageSource is never used; this only ensures the "Inference" input
        // directory exists and does not write any image. Implement or confirm intent.
    }

    /// <summary>
    /// Connects a new client to <paramref name="uri"/> and loads shared properties from it.
    /// Does nothing if already connected.
    /// </summary>
    private async Task ConnectAsyncImpl(Uri uri, CancellationToken cancellationToken = default)
    {
        if (IsConnected)
            return;

        IsConnecting = true;
        try
        {
            logger.LogDebug("Connecting to {@Uri}...", uri);

            var tempClient = new ComfyClient(apiFactory, uri);
            try
            {
                await tempClient.ConnectAsync(cancellationToken);
            }
            catch
            {
                // The client never becomes Client, so nobody else would dispose it.
                tempClient.Dispose();
                throw;
            }

            Client = tempClient;
            logger.LogDebug("Connected to {@Uri}", uri);

            await LoadSharedPropertiesAsync();
        }
        finally
        {
            IsConnecting = false;
        }
    }

    /// <inheritdoc />
    public Task ConnectAsync(CancellationToken cancellationToken = default)
    {
        return ConnectAsyncImpl(
            new UriBuilder("http", DefaultHost, DefaultPort).Uri,
            cancellationToken
        );
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentException">The package is not a ComfyUI package.</exception>
    /// <exception cref="InvalidOperationException">The installed package has no path.</exception>
    public async Task ConnectAsync(
        PackagePair packagePair,
        CancellationToken cancellationToken = default
    )
    {
        if (IsConnected)
            return;

        if (packagePair.BasePackage is not ComfyUI comfyPackage)
        {
            throw new ArgumentException("Base package is not ComfyUI", nameof(packagePair));
        }

        var packagePath =
            packagePair.InstalledPackage.FullPath
            ?? throw new InvalidOperationException("Package does not have a Path");

        // Setup image folder links
        await comfyPackage.SetupInferenceOutputFolderLinks(packagePath);

        // Get user defined host and port
        var host = packagePair.InstalledPackage.GetLaunchArgsHost();
        // Only an exact "localhost" is mapped to the IPv4 loopback; a substring replace
        // would corrupt host names such as "mylocalhost".
        if (
            string.IsNullOrWhiteSpace(host)
            || string.Equals(host, "localhost", StringComparison.OrdinalIgnoreCase)
        )
        {
            host = DefaultHost;
        }

        var portArg = packagePair.InstalledPackage.GetLaunchArgsPort();
        var port = string.IsNullOrWhiteSpace(portArg) ? DefaultPort : int.Parse(portArg);

        var uri = new UriBuilder("http", host, port).Uri;

        await ConnectAsyncImpl(uri, cancellationToken);

        var packageDir = new DirectoryPath(packagePath);

        // Set package paths; ConnectAsyncImpl either set Client or threw.
        var connectedClient = Client!;
        connectedClient.OutputImagesDir = packageDir.JoinDir("output");
        connectedClient.InputImagesDir = packageDir.JoinDir("input");
    }

    /// <summary>
    /// Closes and disposes the current client, then resets shared properties to local defaults.
    /// Does nothing when not connected.
    /// </summary>
    public async Task CloseAsync()
    {
        if (!IsConnected)
            return;

        var closingClient = Client;
        await closingClient.CloseAsync();

        Client = null;
        // Release the client's resources; it is no longer referenced anywhere.
        closingClient.Dispose();

        ResetSharedProperties();
    }

    /// <summary>
    /// Disposes the current client, if any, and clears <see cref="Client"/>.
    /// </summary>
    /// <remarks>
    /// There is intentionally no finalizer: this type owns no unmanaged resources directly,
    /// and touching <see cref="Client"/> from the finalizer thread would dispose a managed
    /// object and raise PropertyChanged off the UI thread.
    /// </remarks>
    public void Dispose()
    {
        Client?.Dispose();
        Client = null;
        GC.SuppressFinalize(this);
    }
}