Multi-Platform Package Manager for Stable Diffusion
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

149 lines
4.5 KiB

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Packages;
namespace StabilityMatrix.Avalonia.Services;
/// <summary>
/// Manager for the current inference client
/// Has observable shared properties for shared info like model names
/// </summary>
public partial class InferenceClientManager : ObservableObject, IInferenceClientManager
{
private readonly ILogger<InferenceClientManager> logger;
private readonly IApiFactory apiFactory;
[ObservableProperty, NotifyPropertyChangedFor(nameof(IsConnected))]
private ComfyClient? client;
[MemberNotNullWhen(true, nameof(Client))]
public bool IsConnected => Client is not null;
[ObservableProperty]
private IReadOnlyCollection<string>? modelNames;
[ObservableProperty]
private IReadOnlyCollection<ComfySampler>? samplers;
[ObservableProperty]
private IReadOnlyCollection<ComfyUpscaler>? upscalers;
public InferenceClientManager(ILogger<InferenceClientManager> logger, IApiFactory apiFactory)
{
this.logger = logger;
this.apiFactory = apiFactory;
}
private async Task LoadSharedPropertiesAsync()
{
if (!IsConnected)
throw new InvalidOperationException("Client is not connected");
ModelNames = await Client.GetModelNamesAsync();
// Fetch sampler names from KSampler node
var samplerNames = await Client.GetSamplerNamesAsync();
Samplers = samplerNames?.Select(name => new ComfySampler(name)).ToImmutableArray();
// Upscalers is latent and esrgan combined
var upscalerBuilder = ImmutableArray.CreateBuilder<ComfyUpscaler>();
// Add latent upscale methods from LatentUpscale node
var latentUpscalerNames = await Client.GetNodeOptionNamesAsync(
"LatentUpscale",
"upscale_method");
if (latentUpscalerNames is not null)
{
upscalerBuilder.AddRange(latentUpscalerNames.Select(
s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)));
}
logger.LogTrace("Loaded latent upscale methods: {@Upscalers}", latentUpscalerNames);
// Add Model upscale methods
var modelUpscalerNames = await Client.GetNodeOptionNamesAsync(
"UpscaleModelLoader",
"model_name");
if (modelUpscalerNames is not null)
{
upscalerBuilder.AddRange(modelUpscalerNames.Select(
s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)));
}
logger.LogTrace("Loaded model upscale methods: {@Upscalers}", modelUpscalerNames);
Upscalers = upscalerBuilder.ToImmutable();
}
protected void ClearSharedProperties()
{
ModelNames = null;
Samplers = null;
Upscalers = null;
}
public async Task ConnectAsync()
{
if (IsConnected)
return;
var tempClient = new ComfyClient(apiFactory, new Uri("http://127.0.0.1:8188"));
await tempClient.ConnectAsync();
Client = tempClient;
await LoadSharedPropertiesAsync();
}
public async Task ConnectAsync(PackagePair packagePair)
{
if (IsConnected)
return;
if (packagePair.BasePackage is not ComfyUI)
{
throw new ArgumentException("Base package is not ComfyUI", nameof(packagePair));
}
var tempClient = new ComfyClient(apiFactory, new Uri("http://127.0.0.1:8188"));
// Add output dir if available
if (packagePair.InstalledPackage.FullPath is { } path)
{
tempClient.OutputImagesDir = new DirectoryPath(path, "output");
}
await tempClient.ConnectAsync();
Client = tempClient;
await LoadSharedPropertiesAsync();
}
public async Task CloseAsync()
{
if (!IsConnected)
return;
await Client.CloseAsync();
Client = null;
ClearSharedProperties();
}
public void Dispose()
{
Client?.Dispose();
Client = null;
GC.SuppressFinalize(this);
}
~InferenceClientManager()
{
Dispose();
}
}