using System;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DynamicData;
using DynamicData.Binding;
using Microsoft.Extensions.Logging;
using SkiaSharp;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Services;

namespace StabilityMatrix.Avalonia.Services;

/// <summary>
/// Manager for the current inference client.
/// Has observable shared properties for shared info like model names.
/// </summary>
[Singleton(typeof(IInferenceClientManager))]
public partial class InferenceClientManager : ObservableObject, IInferenceClientManager
{
    private readonly ILogger<InferenceClientManager> logger;
    private readonly IApiFactory apiFactory;
    private readonly IModelIndexService modelIndexService;
    private readonly ISettingsManager settingsManager;
    private readonly ICompletionProvider completionProvider;

    // CanUserDisconnect depends on IsConnected, so it must be re-notified when the client changes.
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(IsConnected), nameof(CanUserConnect), nameof(CanUserDisconnect))]
    private ComfyClient? client;

    [MemberNotNullWhen(true, nameof(Client))]
    public bool IsConnected => Client is not null;

    // Both CanUserConnect and CanUserDisconnect read IsConnecting.
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(CanUserConnect), nameof(CanUserDisconnect))]
    private bool isConnecting;

    /// <inheritdoc />
    public bool CanUserConnect => !IsConnected && !IsConnecting;

    /// <inheritdoc />
    public bool CanUserDisconnect => IsConnected && !IsConnecting;

    private readonly SourceCache<HybridModelFile, string> modelsSource = new(p => p.GetId());

    public IObservableCollection<HybridModelFile> Models { get; } =
        new ObservableCollectionExtended<HybridModelFile>();

    private readonly SourceCache<HybridModelFile, string> vaeModelsSource = new(p => p.GetId());

    private readonly SourceCache<HybridModelFile, string> vaeModelsDefaults = new(p => p.GetId());

    public IObservableCollection<HybridModelFile> VaeModels { get; } =
        new ObservableCollectionExtended<HybridModelFile>();

    private readonly SourceCache<HybridModelFile, string> controlNetModelsSource = new(p => p.GetId());

    private readonly SourceCache<HybridModelFile, string> downloadableControlNetModelsSource =
        new(p => p.GetId());

    public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
        new ObservableCollectionExtended<HybridModelFile>();

    private readonly SourceCache<HybridModelFile, string> promptExpansionModelsSource =
        new(p => p.GetId());

    private readonly SourceCache<HybridModelFile, string> downloadablePromptExpansionModelsSource =
        new(p => p.GetId());

    public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
        new ObservableCollectionExtended<HybridModelFile>();

    private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);

    public IObservableCollection<ComfySampler> Samplers { get; } =
        new ObservableCollectionExtended<ComfySampler>();

    private readonly SourceCache<ComfyUpscaler, string> modelUpscalersSource = new(p => p.Name);

    private readonly SourceCache<ComfyUpscaler, string> latentUpscalersSource = new(p => p.Name);

    private readonly SourceCache<ComfyUpscaler, string> downloadableUpscalersSource = new(p => p.Name);

    public IObservableCollection<ComfyUpscaler> Upscalers { get; } =
        new ObservableCollectionExtended<ComfyUpscaler>();

    private readonly SourceCache<ComfyScheduler, string> schedulersSource = new(p => p.Name);

    public IObservableCollection<ComfyScheduler> Schedulers { get; } =
        new ObservableCollectionExtended<ComfyScheduler>();

    public InferenceClientManager(
        ILogger<InferenceClientManager> logger,
        IApiFactory apiFactory,
        IModelIndexService modelIndexService,
        ISettingsManager settingsManager,
        ICompletionProvider completionProvider
    )
    {
        this.logger = logger;
        this.apiFactory = apiFactory;
        this.modelIndexService = modelIndexService;
        this.settingsManager = settingsManager;
        this.completionProvider = completionProvider;

        modelsSource
            .Connect()
            .SortBy(
                f => f.ShortDisplayName,
                SortDirection.Ascending,
                SortOptimisations.ComparesImmutableValuesOnly
            )
            .DeferUntilLoaded()
            .Bind(Models)
            .Subscribe();

        // Local and downloadable ControlNet models are merged into one sorted list
        controlNetModelsSource
            .Connect()
            .Or(downloadableControlNetModelsSource.Connect())
            .Sort(
                SortExpressionComparer<HybridModelFile>
                    .Ascending(f => f.Type)
                    .ThenByAscending(f => f.ShortDisplayName)
            )
            .DeferUntilLoaded()
            .Bind(ControlNetModels)
            .Subscribe();

        // Local and downloadable prompt expansion models are merged into one sorted list
        promptExpansionModelsSource
            .Connect()
            .Or(downloadablePromptExpansionModelsSource.Connect())
            .Sort(
                SortExpressionComparer<HybridModelFile>
                    .Ascending(f => f.Type)
                    .ThenByAscending(f => f.ShortDisplayName)
            )
            .DeferUntilLoaded()
            .Bind(PromptExpansionModels)
            .Subscribe();

        vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);

        vaeModelsDefaults.Connect().Or(vaeModelsSource.Connect()).Bind(VaeModels).Subscribe();

        samplersSource.Connect().DeferUntilLoaded().Bind(Samplers).Subscribe();

        // Upscalers list = latent methods + local model upscalers + downloadable model upscalers
        latentUpscalersSource
            .Connect()
            .Or(modelUpscalersSource.Connect())
            .Or(downloadableUpscalersSource.Connect())
            .Sort(
                SortExpressionComparer<ComfyUpscaler>.Ascending(f => f.Type).ThenByAscending(f => f.Name)
            )
            .Bind(Upscalers)
            .Subscribe();

        schedulersSource.Connect().DeferUntilLoaded().Bind(Schedulers).Subscribe();

        settingsManager.RegisterOnLibraryDirSet(_ =>
        {
            Dispatcher.UIThread.Post(ResetSharedProperties, DispatcherPriority.Background);
        });

        // Named handler (not a lambda) so Dispose can unsubscribe from the static event
        EventManager.Instance.ModelIndexChanged += OnModelIndexChanged;
    }

    /// <summary>
    /// Handles the model index changed event: resets shared properties to the local model index
    /// and, if a client is connected, reloads them from the server in the background.
    /// Does nothing beyond logging when the library directory is not set.
    /// </summary>
    private void OnModelIndexChanged(object? sender, EventArgs e)
    {
        logger.LogDebug("Model index changed, reloading shared properties for Inference");

        if (!settingsManager.IsLibraryDirSet)
            return;

        ResetSharedProperties();

        if (IsConnected)
        {
            LoadSharedPropertiesAsync()
                .SafeFireAndForget(
                    onException: ex => logger.LogError(ex, "Error loading shared properties")
                );
        }
    }

    /// <summary>
    /// Throws if no client is connected; afterwards the compiler treats <see cref="Client"/> as non-null.
    /// </summary>
    /// <exception cref="InvalidOperationException">No client is connected.</exception>
    [MemberNotNull(nameof(Client))]
    private void EnsureConnected()
    {
        if (!IsConnected)
            throw new InvalidOperationException("Client is not connected");
    }

    /// <summary>
    /// Queries the connected server for model, sampler, upscaler and scheduler names
    /// and merges them into the shared source caches.
    /// </summary>
    /// <exception cref="InvalidOperationException">No client is connected.</exception>
    private async Task LoadSharedPropertiesAsync()
    {
        EnsureConnected();

        // Get model names
        if (await Client.GetModelNamesAsync() is { } modelNames)
        {
            modelsSource.EditDiff(modelNames.Select(HybridModelFile.FromRemote), HybridModelFile.Comparer);
        }

        // Get control net model names
        if (
            await Client.GetNodeOptionNamesAsync("ControlNetLoader", "control_net_name") is
            { } controlNetModelNames
        )
        {
            controlNetModelsSource.EditDiff(
                controlNetModelNames.Select(HybridModelFile.FromRemote),
                HybridModelFile.Comparer
            );
        }

        // Prompt Expansion indexing is local only

        // Fetch sampler names from KSampler node
        if (await Client.GetSamplerNamesAsync() is { } samplerNames)
        {
            samplersSource.EditDiff(
                samplerNames.Select(name => new ComfySampler(name)),
                ComfySampler.Comparer
            );
        }

        // Upscalers is latent and esrgan combined

        // Add latent upscale methods from LatentUpscale node
        if (
            await Client.GetNodeOptionNamesAsync("LatentUpscale", "upscale_method") is
            { } latentUpscalerNames
        )
        {
            latentUpscalersSource.EditDiff(
                latentUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)),
                ComfyUpscaler.Comparer
            );

            logger.LogTrace("Loaded latent upscale methods: {@Upscalers}", latentUpscalerNames);
        }

        // Add Model upscale methods
        if (
            await Client.GetNodeOptionNamesAsync("UpscaleModelLoader", "model_name") is
            { } modelUpscalerNames
        )
        {
            modelUpscalersSource.EditDiff(
                modelUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)),
                ComfyUpscaler.Comparer
            );
            logger.LogTrace("Loaded model upscale methods: {@Upscalers}", modelUpscalerNames);
        }

        // Add scheduler names from the KSampler node's "scheduler" option.
        // Only names not already present are added; existing entries (e.g. defaults) are kept.
        if (await Client.GetNodeOptionNamesAsync("KSampler", "scheduler") is { } schedulerNames)
        {
            // Snapshot once so the filter is O(1) per name instead of a linear Keys scan
            var existingSchedulers = schedulersSource.Keys.ToHashSet();

            schedulersSource.Edit(updater =>
            {
                updater.AddOrUpdate(
                    schedulerNames
                        .Where(n => !existingSchedulers.Contains(n))
                        .Select(s => new ComfyScheduler(s))
                );
            });
            logger.LogTrace("Loaded scheduler methods: {@Schedulers}", schedulerNames);
        }
    }

    /// <summary>
    /// Clears shared properties and sets them to local defaults
    /// (local model index entries, built-in sampler/upscaler/scheduler defaults,
    /// and downloadable models not present locally).
    /// </summary>
    private void ResetSharedProperties()
    {
        // Load local models
        modelsSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(SharedFolderType.StableDiffusion)
                .Select(HybridModelFile.FromLocal),
            HybridModelFile.Comparer
        );

        // Load local control net models
        controlNetModelsSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(SharedFolderType.ControlNet)
                .Select(HybridModelFile.FromLocal),
            HybridModelFile.Comparer
        );

        // Downloadable ControlNet models (only those not already available locally)
        var downloadableControlNets = RemoteModels.ControlNetModels.Where(
            u => !controlNetModelsSource.Lookup(u.GetId()).HasValue
        );
        downloadableControlNetModelsSource.EditDiff(downloadableControlNets, HybridModelFile.Comparer);

        // Load local prompt expansion models
        promptExpansionModelsSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(SharedFolderType.PromptExpansion)
                .Select(HybridModelFile.FromLocal),
            HybridModelFile.Comparer
        );

        // Downloadable PromptExpansion models (only those not already available locally)
        downloadablePromptExpansionModelsSource.EditDiff(
            RemoteModels.PromptExpansionModels.Where(
                u => !promptExpansionModelsSource.Lookup(u.GetId()).HasValue
            ),
            HybridModelFile.Comparer
        );

        // Load local VAE models
        vaeModelsSource.EditDiff(
            modelIndexService.GetFromModelIndex(SharedFolderType.VAE).Select(HybridModelFile.FromLocal),
            HybridModelFile.Comparer
        );

        samplersSource.EditDiff(ComfySampler.Defaults, ComfySampler.Comparer);

        latentUpscalersSource.EditDiff(ComfyUpscaler.Defaults, ComfyUpscaler.Comparer);

        schedulersSource.EditDiff(ComfyScheduler.Defaults, ComfyScheduler.Comparer);

        // Load Upscalers
        modelUpscalersSource.EditDiff(
            modelIndexService
                .GetFromModelIndex(
                    SharedFolderType.ESRGAN | SharedFolderType.RealESRGAN | SharedFolderType.SwinIR
                )
                .Select(m => new ComfyUpscaler(m.FileName, ComfyUpscalerType.ESRGAN)),
            ComfyUpscaler.Comparer
        );

        // Remote upscalers (only those not already available locally)
        var remoteUpscalers = ComfyUpscaler.DefaultDownloadableModels.Where(
            u => !modelUpscalersSource.Lookup(u.Name).HasValue
        );
        downloadableUpscalersSource.EditDiff(remoteUpscalers, ComfyUpscaler.Comparer);
    }

    /// <inheritdoc />
    public async Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default)
    {
        EnsureConnected();

        if (image.LocalFile is not { } localFile)
        {
            throw new ArgumentException("Image is not a local file", nameof(image));
        }

        var uploadName = await image.GetHashGuidFileNameAsync();

        await using var stream = localFile.Info.OpenRead();

        await Client.UploadImageAsync(stream, uploadName, cancellationToken);
    }

    /// <inheritdoc />
    public async Task CopyImageToInputAsync(FilePath imageFile, CancellationToken cancellationToken = default)
    {
        if (!IsConnected)
            return;

        if (Client.InputImagesDir is not { } inputImagesDir)
        {
            throw new InvalidOperationException("InputImagesDir is null");
        }

        var inferenceInputs = inputImagesDir.JoinDir("Inference");
        inferenceInputs.Create();

        var destination = inferenceInputs.JoinFile(imageFile.Name);

        // Read to SKImage then write to file, to prevent errors from metadata
        await Task.Run(
            () =>
            {
                using var imageStream = imageFile.Info.OpenRead();
                using var image =
                    SKImage.FromEncodedData(imageStream)
                    ?? throw new InvalidOperationException($"Failed to decode image {imageFile.Name}");
                using var encoded =
                    image.Encode(SKEncodedImageFormat.Png, 100)
                    ?? throw new InvalidOperationException(
                        $"Failed to encode image {imageFile.Name} as PNG"
                    );

                // FileInfo.Create truncates an existing file; OpenWrite would keep stale trailing
                // bytes when overwriting a larger file, corrupting the PNG.
                using var destinationStream = destination.Info.Create();
                encoded.SaveTo(destinationStream);
            },
            cancellationToken
        );
    }

    /// <inheritdoc />
    public async Task WriteImageToInputAsync(
        ImageSource imageSource,
        CancellationToken cancellationToken = default
    )
    {
        if (!IsConnected)
            return;

        if (Client.InputImagesDir is not { } inputImagesDir)
        {
            throw new InvalidOperationException("InputImagesDir is null");
        }

        var inferenceInputs = inputImagesDir.JoinDir("Inference");
        inferenceInputs.Create();

        // TODO(review): imageSource is currently unused - this only ensures the Inference input
        // directory exists and does not write any image yet. Confirm intended behavior.
    }

    /// <summary>
    /// Creates and connects a new <see cref="ComfyClient"/> to <paramref name="uri"/>, then loads
    /// shared properties from it. No-op if already connected. On any failure the partially
    /// created client is disposed, <see cref="Client"/> is left null, and the exception is rethrown.
    /// </summary>
    [MemberNotNull(nameof(Client))]
    private async Task ConnectAsyncImpl(Uri uri, CancellationToken cancellationToken = default)
    {
        if (IsConnected)
            return;

        IsConnecting = true;

        ComfyClient? tempClient = null;
        try
        {
            logger.LogDebug("Connecting to {@Uri}...", uri);

            tempClient = new ComfyClient(apiFactory, uri);
            await tempClient.ConnectAsync(cancellationToken);
            logger.LogDebug("Connected to {@Uri}", uri);

            Client = tempClient;
            await LoadSharedPropertiesAsync();
        }
        catch (Exception)
        {
            Client = null;
            // Release the socket/HTTP resources of the client that never became usable
            tempClient?.Dispose();
            throw;
        }
        finally
        {
            IsConnecting = false;
        }
    }

    /// <inheritdoc />
    public Task ConnectAsync(CancellationToken cancellationToken = default)
    {
        return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
    }

    /// <summary>
    /// Ensures the shared Inference images directory exists, and removes the legacy
    /// <c>output/Inference</c> symbolic link inside the package's install directory if present.
    /// A non-symlink directory at that path is left in place and a warning is logged.
    /// </summary>
    /// <exception cref="ArgumentException">The installed package has no full path.</exception>
    private async Task MigrateLinksIfNeeded(PackagePair packagePair)
    {
        if (packagePair.InstalledPackage.FullPath is not { } packagePath)
        {
            throw new ArgumentException("Package path is null", nameof(packagePair));
        }

        var inferenceDir = settingsManager.ImagesInferenceDirectory;
        inferenceDir.Create();

        // For locally installed packages only
        // Delete ./output/Inference
        var legacyInferenceLinkDir = new DirectoryPath(packagePath).JoinDir("output", "Inference");

        if (!legacyInferenceLinkDir.Exists)
            return;

        if (!legacyInferenceLinkDir.IsSymbolicLink)
        {
            logger.LogWarning(
                "Legacy inference link at {LegacyDir} is not a symbolic link, skipping",
                legacyInferenceLinkDir
            );
            return;
        }

        // Only log the deletion when it actually happens
        logger.LogInformation("Deleting legacy inference link at {LegacyDir}", legacyInferenceLinkDir);
        await legacyInferenceLinkDir.DeleteAsync(false);
    }

    /// <inheritdoc />
    public async Task ConnectAsync(PackagePair packagePair, CancellationToken cancellationToken = default)
    {
        if (IsConnected)
            return;

        if (packagePair.BasePackage is not ComfyUI comfyPackage)
        {
            throw new ArgumentException("Base package is not ComfyUI", nameof(packagePair));
        }

        // Setup completion provider
        completionProvider
            .Setup()
            .SafeFireAndForget(ex =>
            {
                logger.LogError(ex, "Error setting up completion provider");
            });

        await MigrateLinksIfNeeded(packagePair);

        // Get user defined host and port
        var host = packagePair.InstalledPackage.GetLaunchArgsHost();
        if (string.IsNullOrWhiteSpace(host))
        {
            host = "127.0.0.1";
        }
        else if (host.Equals("localhost", StringComparison.OrdinalIgnoreCase))
        {
            // Normalize to the IPv4 loopback address. Exact match only: a substring replace
            // would also rewrite hosts like "localhost.example.com".
            host = "127.0.0.1";
        }

        var port = packagePair.InstalledPackage.GetLaunchArgsPort();
        if (string.IsNullOrWhiteSpace(port))
        {
            port = "8188";
        }

        // Port is machine-readable config, so parse culture-invariantly
        var uri = new UriBuilder("http", host, int.Parse(port, CultureInfo.InvariantCulture)).Uri;

        await ConnectAsyncImpl(uri, cancellationToken);

        Client.LocalServerPackage = packagePair;
        Client.LocalServerPath = packagePair.InstalledPackage.FullPath!;
    }

    /// <summary>
    /// Closes and disposes the current client (if any) and resets shared properties to local defaults.
    /// If closing throws, the client is left in place and the exception propagates.
    /// </summary>
    public async Task CloseAsync()
    {
        if (!IsConnected)
            return;

        await Client.CloseAsync();
        // The closed client is discarded; release its resources instead of waiting for the GC
        Client.Dispose();
        Client = null;
        ResetSharedProperties();
    }

    /// <summary>
    /// Unsubscribes from the model index event and disposes the current client, if any.
    /// </summary>
    /// <remarks>
    /// There is intentionally no finalizer: this type owns no unmanaged resources directly, and
    /// disposing managed objects / raising PropertyChanged from the finalizer thread is unsafe.
    /// </remarks>
    public void Dispose()
    {
        EventManager.Instance.ModelIndexChanged -= OnModelIndexChanged;

        Client?.Dispose();
        Client = null;

        GC.SuppressFinalize(this);
    }
}