Browse Source

Add Extension prompt checking

pull/495/head
Ionite 9 months ago
parent
commit
49398c3386
No known key found for this signature in database
  1. 101
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  2. 1
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs
  3. 1
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

101
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -15,11 +15,15 @@ using Avalonia.Controls.Notifications;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using ExifLibrary;
using FluentAvalonia.UI.Controls;
using KGySoft.CoreLibraries;
using Nito.Disposables.Internals;
using NLog;
using Refit;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
@ -35,8 +39,11 @@ using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Packages.Extensions;
using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Services;
using Windows.ApplicationModel;
using Notification = DesktopNotifications.Notification;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
@ -272,6 +279,15 @@ public abstract partial class InferenceGenerationViewModelBase
if (client.OutputImagesDir is null)
throw new InvalidOperationException("OutputImagesDir is null");
// Only check extensions for first batch index
if (args.BatchIndex == 0)
{
if (!await CheckPromptExtensionsInstalled(args.Nodes))
{
throw new ValidationException("Prompt extensions not installed");
}
}
// Upload input images
await UploadInputImages(client);
@ -621,6 +637,90 @@ public abstract partial class InferenceGenerationViewModelBase
return ClientManager.IsConnected;
}
/// <summary>
/// Shows a dialog and return false if prompt required extensions not installed
/// </summary>
private async Task<bool> CheckPromptExtensionsInstalled(NodeDictionary nodeDictionary)
{
// Get prompt required extensions
// Just static for now but could do manifest lookup when we support custom workflows
var requiredExtensions = nodeDictionary
.ClassTypeRequiredExtensions.Values.SelectMany(x => x)
.ToHashSet();
// Skip if no extensions required
if (requiredExtensions.Count == 0)
{
return true;
}
// Get installed extensions
var localPackagePair = ClientManager.Client?.LocalServerPackage.Unwrap()!;
var manager = localPackagePair.BasePackage.ExtensionManager.Unwrap();
var localExtensions = (
await ((GitPackageExtensionManager)manager).GetInstalledExtensionsLiteAsync(
localPackagePair.InstalledPackage
)
).ToImmutableArray();
var missingExtensions = requiredExtensions
.Except(localExtensions.Select(ext => ext.GitRepositoryUrl).WhereNotNull())
.ToImmutableArray();
if (missingExtensions.Length == 0)
{
return true;
}
var dialog = DialogHelper.CreateMarkdownDialog(
$"#### The following extensions are required for this workflow:\n"
+ $"{string.Join("\n- ", missingExtensions)}",
"Install Required Extensions?"
);
dialog.IsPrimaryButtonEnabled = true;
dialog.DefaultButton = ContentDialogButton.Primary;
dialog.PrimaryButtonText =
$"{Resources.Action_Install} ({localPackagePair.InstalledPackage.DisplayName.ToRepr()} will restart)";
dialog.CloseButtonText = Resources.Action_Cancel;
if (await dialog.ShowAsync() == ContentDialogResult.Primary)
{
var manifestExtensionsMap = await manager.GetManifestExtensionsMapAsync(
manager.GetManifests(localPackagePair.InstalledPackage)
);
var steps = new List<IPackageStep>();
foreach (var missingExtensionUrl in missingExtensions)
{
if (!manifestExtensionsMap.TryGetValue(missingExtensionUrl, out var extension))
{
Logger.Warn(
"Extension {MissingExtensionUrl} not found in manifests",
missingExtensionUrl
);
continue;
}
steps.Add(new InstallExtensionStep(manager, localPackagePair.InstalledPackage, extension));
}
var runner = new PackageModificationRunner
{
ShowDialogOnStart = true,
ModificationCompleteTitle = "Extensions Installed",
ModificationCompleteMessage = "Finished installing required extensions"
};
EventManager.Instance.OnPackageInstallProgressAdded(runner);
runner.ExecuteSteps(steps).SafeFireAndForget();
}
return false;
}
/// <summary>
/// Handles the preview image received event from the websocket.
/// Updates the preview image in the image gallery.
@ -683,6 +783,7 @@ public abstract partial class InferenceGenerationViewModelBase
public required ComfyClient Client { get; init; }
public required NodeDictionary Nodes { get; init; }
public required IReadOnlyList<string> OutputNodeNames { get; init; }
public int BatchIndex { get; init; }
public GenerationParameters? Parameters { get; init; }
public InferenceProjectDocument? Project { get; init; }
public bool ClearOutputImages { get; init; } = true;

1
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs

@ -200,6 +200,7 @@ public partial class InferenceImageToVideoViewModel
Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this),
FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch
ClearOutputImages = i == 0
};

1
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -214,6 +214,7 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, I
Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this),
FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch
ClearOutputImages = i == 0
};

Loading…
Cancel
Save