From 49398c33867417a5fec160a72b56741509724a2e Mon Sep 17 00:00:00 2001 From: Ionite Date: Wed, 21 Feb 2024 16:22:41 -0500 Subject: [PATCH] Add Extension prompt checking --- .../Base/InferenceGenerationViewModelBase.cs | 101 ++++++++++++++++++ .../InferenceImageToVideoViewModel.cs | 1 + .../InferenceTextToImageViewModel.cs | 1 + 3 files changed, 103 insertions(+) diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index f73fa6f5..3846f045 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -15,11 +15,15 @@ using Avalonia.Controls.Notifications; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using ExifLibrary; +using FluentAvalonia.UI.Controls; +using KGySoft.CoreLibraries; +using Nito.Disposables.Internals; using NLog; using Refit; using SkiaSharp; using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; @@ -35,8 +39,11 @@ using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.PackageModification; +using StabilityMatrix.Core.Models.Packages.Extensions; using StabilityMatrix.Core.Models.Settings; using StabilityMatrix.Core.Services; +using Windows.ApplicationModel; using Notification = DesktopNotifications.Notification; namespace StabilityMatrix.Avalonia.ViewModels.Base; @@ -272,6 +279,15 @@ public abstract partial class InferenceGenerationViewModelBase if (client.OutputImagesDir is null) throw new InvalidOperationException("OutputImagesDir is null"); + // Only check extensions for first batch index + if (args.BatchIndex == 0) + { + if (!await CheckPromptExtensionsInstalled(args.Nodes)) + { + throw new ValidationException("Prompt extensions not installed"); + } + } + // Upload input images await UploadInputImages(client); @@ -621,6 +637,90 @@ public abstract partial class InferenceGenerationViewModelBase return ClientManager.IsConnected; } + /// + /// Shows a dialog and return false if prompt required extensions not installed + /// + private async Task CheckPromptExtensionsInstalled(NodeDictionary nodeDictionary) + { + // Get prompt required extensions + // Just static for now but could do manifest lookup when we support custom workflows + var requiredExtensions = nodeDictionary + .ClassTypeRequiredExtensions.Values.SelectMany(x => x) + .ToHashSet(); + + // Skip if no extensions required + if (requiredExtensions.Count == 0) + { + return true; + } + + // Get installed extensions + var localPackagePair = ClientManager.Client?.LocalServerPackage.Unwrap()!; + var manager = localPackagePair.BasePackage.ExtensionManager.Unwrap(); + + var localExtensions = ( + await ((GitPackageExtensionManager)manager).GetInstalledExtensionsLiteAsync( + localPackagePair.InstalledPackage + ) + ).ToImmutableArray(); + + var missingExtensions = requiredExtensions + .Except(localExtensions.Select(ext => ext.GitRepositoryUrl).WhereNotNull()) + .ToImmutableArray(); + + if (missingExtensions.Length == 0) + { + return true; + } + + var dialog = DialogHelper.CreateMarkdownDialog( + $"#### The following extensions are required for this workflow:\n" + + $"{string.Join("\n- ", missingExtensions)}", + "Install Required Extensions?" + ); + + dialog.IsPrimaryButtonEnabled = true; + dialog.DefaultButton = ContentDialogButton.Primary; + dialog.PrimaryButtonText = + $"{Resources.Action_Install} ({localPackagePair.InstalledPackage.DisplayName.ToRepr()} will restart)"; + dialog.CloseButtonText = Resources.Action_Cancel; + + if (await dialog.ShowAsync() == ContentDialogResult.Primary) + { + var manifestExtensionsMap = await manager.GetManifestExtensionsMapAsync( + manager.GetManifests(localPackagePair.InstalledPackage) + ); + + var steps = new List(); + + foreach (var missingExtensionUrl in missingExtensions) + { + if (!manifestExtensionsMap.TryGetValue(missingExtensionUrl, out var extension)) + { + Logger.Warn( + "Extension {MissingExtensionUrl} not found in manifests", + missingExtensionUrl + ); + continue; + } + + steps.Add(new InstallExtensionStep(manager, localPackagePair.InstalledPackage, extension)); + } + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteTitle = "Extensions Installed", + ModificationCompleteMessage = "Finished installing required extensions" + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + runner.ExecuteSteps(steps).SafeFireAndForget(); + } + + return false; + } + /// /// Handles the preview image received event from the websocket. /// Updates the preview image in the image gallery. @@ -683,6 +783,7 @@ public abstract partial class InferenceGenerationViewModelBase public required ComfyClient Client { get; init; } public required NodeDictionary Nodes { get; init; } public required IReadOnlyList OutputNodeNames { get; init; } + public int BatchIndex { get; init; } public GenerationParameters? Parameters { get; init; } public InferenceProjectDocument? Project { get; init; } public bool ClearOutputImages { get; init; } = true; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs index e977bda9..b57e3406 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs @@ -200,6 +200,7 @@ public partial class InferenceImageToVideoViewModel Parameters = SaveStateToParameters(new GenerationParameters()), Project = InferenceProjectDocument.FromLoadable(this), FilesToTransfer = buildPromptArgs.FilesToTransfer, + BatchIndex = i, // Only clear output images on the first batch ClearOutputImages = i == 0 }; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 0caab3a8..12d0d5fc 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -214,6 +214,7 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, I Parameters = SaveStateToParameters(new GenerationParameters()), Project = InferenceProjectDocument.FromLoadable(this), FilesToTransfer = buildPromptArgs.FilesToTransfer, + BatchIndex = i, // Only clear output images on the first batch ClearOutputImages = i == 0 };