Multi-Platform Package Manager for Stable Diffusion
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

893 lines
31 KiB

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls.Notifications;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using ExifLibrary;
using FluentAvalonia.UI.Controls;
using Microsoft.Extensions.DependencyInjection;
using NLog;
using Refit;
using Semver;
using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Packages.Extensions;
using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Services;
using Notification = DesktopNotifications.Notification;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
/// <summary>
/// Abstract base class for tab view models that generate images using ClientManager.
/// This includes a progress reporter, image output view model, and generation virtual methods.
/// </summary>
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract partial class InferenceGenerationViewModelBase
: InferenceTabViewModelBase,
IImageGalleryComponent
{
/// <summary>Logger for this class.</summary>
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
/// <summary>Settings manager; supplies the inference images directory and the output file name format.</summary>
private readonly ISettingsManager settingsManager;
/// <summary>Notification service used to show warnings, errors, and the prompt-completed notification.</summary>
private readonly INotificationService notificationService;
/// <summary>Factory used to resolve the card view models and the connection help dialog view model.</summary>
private readonly ServiceManager<ViewModelBase> vmFactory;
/// <summary>
/// Gallery card that receives preview images and generated output images.
/// Serialized under the JSON property name "ImageGallery".
/// </summary>
[JsonPropertyName("ImageGallery")]
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
/// <summary>Image folder card resolved from the view model factory; excluded from JSON serialization.</summary>
[JsonIgnore]
public ImageFolderCardViewModel ImageFolderCardViewModel { get; }
/// <summary>
/// Progress state updated by the progress and running-node handlers of this class;
/// excluded from JSON serialization.
/// </summary>
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
/// <summary>Inference client manager supplied to the constructor; excluded from JSON serialization.</summary>
[JsonIgnore]
public IInferenceClientManager ClientManager { get; }
/// <summary>
/// Stores the injected services, resolves the gallery and folder card view models
/// from the factory, and registers the generate command with
/// <c>WithConditionalNotificationErrorHandler</c>.
/// </summary>
protected InferenceGenerationViewModelBase(
    ServiceManager<ViewModelBase> vmFactory,
    IInferenceClientManager inferenceClientManager,
    INotificationService notificationService,
    ISettingsManager settingsManager
)
    : base(notificationService)
{
    // Injected services
    ClientManager = inferenceClientManager;
    this.vmFactory = vmFactory;
    this.settingsManager = settingsManager;
    this.notificationService = notificationService;

    // Child view models (resolved in the same order as before)
    ImageGalleryCardViewModel = this.vmFactory.Get<ImageGalleryCardViewModel>();
    ImageFolderCardViewModel = this.vmFactory.Get<ImageFolderCardViewModel>();

    GenerateImageCommand.WithConditionalNotificationErrorHandler(this.notificationService);
}
/// <summary>
/// Writes an image to the inference images directory from settings.
/// The directory is created before delegating to the overload that takes an output directory.
/// </summary>
protected Task<FilePath> WriteOutputImageAsync(
    Stream imageStream,
    ImageGenerationEventArgs args,
    int batchNum = 0,
    int batchTotal = 0,
    bool isGrid = false,
    string fileExtension = "png"
)
{
    var inferenceImagesDir = settingsManager.ImagesInferenceDirectory;
    inferenceImagesDir.Create();

    return WriteOutputImageAsync(
        imageStream: imageStream,
        outputDir: inferenceImagesDir,
        args: args,
        batchNum: batchNum,
        batchTotal: batchTotal,
        isGrid: isGrid,
        fileExtension: fileExtension
    );
}
/// <summary>
/// Writes an image to the given output folder.
/// The file name is built from the configured file name format (falling back to the
/// default template), optionally with a grid prefix and batch postfix. If the name is
/// taken, numeric suffixes _1.._100 are tried, then a 7-character GUID suffix.
/// </summary>
/// <returns>The path of the file that was written.</returns>
protected async Task<FilePath> WriteOutputImageAsync(
    Stream imageStream,
    DirectoryPath outputDir,
    ImageGenerationEventArgs args,
    int batchNum = 0,
    int batchTotal = 0,
    bool isGrid = false,
    string fileExtension = "png"
)
{
    var template = settingsManager.Settings.InferenceOutputImageFileNameFormat;
    var provider = new FileNameFormatProvider
    {
        GenerationParameters = args.Parameters,
        ProjectType = args.Project?.ProjectType,
        ProjectName = ProjectFile?.NameWithoutExtension
    };

    // Use the configured template when it is present and parses, otherwise the default one
    if (!(!string.IsNullOrEmpty(template) && FileNameFormat.TryParse(template, provider, out var format)))
    {
        Logger.Warn("Failed to parse format template: {FormatTemplate}, using default", template);
        format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, provider);
    }

    if (isGrid)
    {
        format = format.WithGridPrefix();
    }

    if (batchNum >= 1 && batchTotal > 1)
    {
        format = format.WithBatchPostFix(batchNum, batchTotal);
    }

    var baseName = format.GetFileName();
    var target = outputDir.JoinFile($"{baseName}.{fileExtension}");

    // While the name is taken, try _1, _2, ... up to _100
    var suffix = 1;
    while (suffix <= 100 && target.Exists)
    {
        target = outputDir.JoinFile($"{baseName}_{suffix}.{fileExtension}");
        suffix++;
    }

    // Still taken: fall back to a 7-character GUID suffix
    if (target.Exists)
    {
        var shortGuid = Guid.NewGuid().ToString("N").Substring(0, 7);
        target = outputDir.JoinFile($"{baseName}_{shortGuid}.{fileExtension}");
    }

    var parentDirectory = target.Info.DirectoryName;
    if (parentDirectory != null)
    {
        Directory.CreateDirectory(parentDirectory);
    }

    await using (var destination = target.Info.OpenWrite())
    {
        await imageStream.CopyToAsync(destination);
    }

    return target;
}
/// <summary>
/// Builds the image generation prompt.
/// The base implementation does nothing.
/// </summary>
protected virtual void BuildPrompt(BuildPromptEventArgs args)
{
    // No-op by default
}
/// <summary>
/// Uploads the files required for the prompt, one at a time, in enumeration order.
/// </summary>
/// <param name="files">Pairs of local source path and destination path relative to the client</param>
/// <param name="client">Client that receives the files</param>
protected virtual async Task UploadPromptFiles(
    IEnumerable<(string SourcePath, string DestinationRelativePath)> files,
    ComfyClient client
)
{
    foreach (var transfer in files)
    {
        Logger.Debug(
            "Uploading prompt file {SourcePath} to relative path {DestinationPath}",
            transfer.SourcePath,
            transfer.DestinationRelativePath
        );

        await client.UploadFileAsync(transfer.SourcePath, transfer.DestinationRelativePath);
    }
}
/// <summary>
/// Gets the ImageSources that need to be uploaded as inputs.
/// The base implementation returns an empty sequence.
/// </summary>
protected virtual IEnumerable<ImageSource> GetInputImages() => Enumerable.Empty<ImageSource>();
/// <summary>
/// Uploads every input image from <see cref="GetInputImages"/> that has a local file.
/// PNG files are uploaded with their metadata removed; other files are streamed as-is.
/// </summary>
/// <param name="client">Client that receives the images</param>
protected async Task UploadInputImages(ComfyClient client)
{
    foreach (var inputImage in GetInputImages())
    {
        // Only images backed by a local file are uploaded
        if (inputImage.LocalFile is not { } sourceFile)
        {
            continue;
        }

        var remoteName = await inputImage.GetHashGuidFileNameAsync();
        Logger.Debug("Uploading image {FileName} as {UploadName}", sourceFile.Name, remoteName);

        var isPng = sourceFile.Info.Extension.Equals(".png", StringComparison.OrdinalIgnoreCase);
        if (!isPng)
        {
            await using var fileStream = sourceFile.Info.OpenRead();
            await client.UploadImageAsync(fileStream, remoteName);
            continue;
        }

        // For pngs, strip metadata since Pillow can't handle some valid files?
        var originalBytes = await sourceFile.ReadAllBytesAsync();
        var strippedBytes = PngDataHelper.RemoveMetadata(originalBytes);
        using var strippedStream = new MemoryStream(strippedBytes);
        await client.UploadImageAsync(strippedStream, remoteName);
    }
}
/// <summary>
/// Runs a generation task: validates the arguments, uploads input images and prompt files,
/// queues the prompt, waits for it to finish, then saves the output images and shows a
/// completion notification.
/// </summary>
/// <param name="args">Client, prompt nodes, and metadata for this run</param>
/// <param name="cancellationToken">When cancelled while the run is active, the client is asked to interrupt the prompt</param>
/// <exception cref="InvalidOperationException">
/// Thrown if args.Parameters or args.Project is null, args.OutputNodeNames is empty,
/// or the client's OutputImagesDir is null
/// </exception>
/// <exception cref="ValidationException">
/// Thrown for batch index 0 if the prompt's required extensions are not installed
/// </exception>
protected async Task RunGeneration(ImageGenerationEventArgs args, CancellationToken cancellationToken)
{
    var client = args.Client;
    var nodes = args.Nodes;

    // Checks
    if (args.Parameters is null)
        throw new InvalidOperationException("Parameters is null");
    if (args.Project is null)
        throw new InvalidOperationException("Project is null");
    if (args.OutputNodeNames.Count == 0)
        throw new InvalidOperationException("OutputNodeNames is empty");
    if (client.OutputImagesDir is null)
        throw new InvalidOperationException("OutputImagesDir is null");

    // Only check extensions for first batch index
    if (args.BatchIndex == 0)
    {
        if (!await CheckPromptExtensionsInstalled(args.Nodes))
        {
            throw new ValidationException("Prompt extensions not installed");
        }
    }

    // Upload input images
    await UploadInputImages(client);

    // Upload required files
    await UploadPromptFiles(args.FilesToTransfer, client);

    // Connect preview image handler
    client.PreviewImageReceived += OnPreviewImageReceived;

    // Register to interrupt if user cancels
    var promptInterrupt = cancellationToken.Register(() =>
    {
        Logger.Info("Cancelling prompt");
        client
            .InterruptPromptAsync(new CancellationTokenSource(5000).Token)
            .SafeFireAndForget(ex =>
            {
                Logger.Warn(ex, "Error while interrupting prompt");
            });
    });

    ComfyTask? promptTask = null;

    try
    {
        var timer = Stopwatch.StartNew();

        try
        {
            promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
        }
        catch (ApiException e)
        {
            Logger.Warn(e, "Api exception while queuing prompt");
            await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
            return;
        }

        // Register progress handler
        promptTask.ProgressUpdate += OnProgressUpdateReceived;

        // Delay attaching running node change handler to not show indeterminate progress
        // if progress updates are received before the prompt starts
        Task.Run(
                async () =>
                {
                    try
                    {
                        var delayTime = 250 - (int)timer.ElapsedMilliseconds;
                        if (delayTime > 0)
                        {
                            await Task.Delay(delayTime, cancellationToken);
                        }

                        // ReSharper disable once AccessToDisposedClosure
                        AttachRunningNodeChangedHandler(promptTask);
                    }
                    catch (TaskCanceledException) { }
                },
                cancellationToken
            )
            // Observe failures of this background task (including cancellation before it
            // starts, which the inner catch cannot see) instead of leaving them unhandled
            .SafeFireAndForget(ex =>
            {
                if (ex is not OperationCanceledException)
                {
                    Logger.Warn(ex, "Error while attaching running node handler");
                }
            });

        // Wait for prompt to finish
        try
        {
            await promptTask.Task.WaitAsync(cancellationToken);
            Logger.Debug($"Prompt task {promptTask.Id} finished");
        }
        catch (ComfyNodeException e)
        {
            Logger.Warn(e, "Comfy node exception while queuing prompt");
            await DialogHelper
                .CreateJsonDialog(e.JsonData, "Comfy Error", "Node execution encountered an error")
                .ShowAsync();
            return;
        }

        // Get output images
        var imageOutputs = await client.GetImagesForExecutedPromptAsync(promptTask.Id, cancellationToken);

        if (imageOutputs.Values.All(images => images is null or { Count: 0 }))
        {
            // No images match
            notificationService.Show(
                "No output",
                "Did not receive any output images",
                NotificationType.Warning
            );
            return;
        }

        // Disable cancellation
        await promptInterrupt.DisposeAsync();

        if (args.ClearOutputImages)
        {
            ImageGalleryCardViewModel.ImageSources.Clear();
        }

        var outputImages = await ProcessAllOutputImages(imageOutputs, args);

        var notificationImage = outputImages.FirstOrDefault()?.LocalFile;

        await notificationService.ShowAsync(
            NotificationKey.Inference_PromptCompleted,
            new Notification
            {
                Title = "Prompt Completed",
                Body = $"Prompt [{promptTask.Id[..7].ToLower()}] completed successfully",
                BodyImagePath = notificationImage?.FullPath
            }
        );
    }
    finally
    {
        // Release the interrupt registration on every non-cancelled exit (early return or
        // failure), so a later cancel of the same token does not send an interrupt for a
        // run that is already over. If cancellation was already requested, leave it alone:
        // the callback has run or is about to, and must not be unregistered.
        // Disposing again after the success-path dispose above is a no-op.
        if (!cancellationToken.IsCancellationRequested)
        {
            await promptInterrupt.DisposeAsync();
        }

        // Disconnect progress handler
        client.PreviewImageReceived -= OnPreviewImageReceived;

        // Detach task handlers so late events cannot touch the progress state cleared below
        if (promptTask is not null)
        {
            promptTask.ProgressUpdate -= OnProgressUpdateReceived;
            promptTask.RunningNodeChanged -= OnRunningNodeChanged;
        }

        // Clear progress
        OutputProgress.ClearProgress();
        ImageGalleryCardViewModel.PreviewImage?.Dispose();
        ImageGalleryCardViewModel.PreviewImage = null;
        ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;

        // Cleanup tasks
        promptTask?.Dispose();
    }
}
/// <summary>
/// Processes the output images of every node and returns them as one sequence.
/// Nodes with a null image list are logged and skipped. Underscores in the node name
/// are replaced by spaces to form the image label.
/// </summary>
private async Task<IEnumerable<ImageSource>> ProcessAllOutputImages(
    IReadOnlyDictionary<string, List<ComfyImage>?> images,
    ImageGenerationEventArgs args
)
{
    var collected = new List<ImageSource>();

    foreach (var entry in images)
    {
        if (entry.Value is not { } nodeImages)
        {
            Logger.Warn("No images for node {NodeName}", entry.Key);
            continue;
        }

        var label = entry.Key.Replace('_', ' ');
        var processed = await ProcessOutputImages(nodeImages, args, label);
        collected.AddRange(processed);
    }

    return collected;
}
/// <summary>
/// Downloads the output images of a generation run, embeds generation metadata
/// (PNG and WebP only), writes them to the output folder, builds a grid image when
/// there is more than one output, and adds everything to the image gallery.
/// </summary>
/// <param name="images">Images reported by the client for one output node</param>
/// <param name="args">Generation arguments; Parameters and Project must be non-null</param>
/// <param name="imageLabel">Optional label applied to each saved image (not to the grid)</param>
/// <returns>The saved images, with the grid image first when one was created</returns>
private async Task<List<ImageSource>> ProcessOutputImages(
    IReadOnlyCollection<ComfyImage> images,
    ImageGenerationEventArgs args,
    string? imageLabel = null
)
{
    var client = args.Client;

    // Write metadata to images
    var outputImagesBytes = new List<byte[]>();
    var outputImages = new List<ImageSource>();

    // Created on first use and shared by all webp images of this call
    JsonSerializerOptions? webpJsonOptions = null;

    foreach (var (i, comfyImage) in images.Enumerate())
    {
        Logger.Debug("Downloading image: {FileName}", comfyImage.FileName);

        // Buffer the download, releasing the source stream as soon as it is copied
        byte[] imageArray;
        await using (var imageStream = await client.GetImageStreamAsync(comfyImage))
        using (var ms = new MemoryStream())
        {
            await imageStream.CopyToAsync(ms);
            imageArray = ms.ToArray();
        }

        outputImagesBytes.Add(imageArray);

        var parameters = args.Parameters!;
        var project = args.Project!;

        // Lock seed
        project.TryUpdateModel<SeedCardModel>("Seed", model => model with { IsRandomizeEnabled = false });

        // Seed and batch override for batches
        if (images.Count > 1 && project.ProjectType is InferenceProjectType.TextToImage)
        {
            project = (InferenceProjectDocument)project.Clone();

            // Set batch size indexes
            project.TryUpdateModel(
                "BatchSize",
                node =>
                {
                    node[nameof(BatchSizeCardViewModel.BatchCount)] = 1;
                    node[nameof(BatchSizeCardViewModel.IsBatchIndexEnabled)] = true;
                    node[nameof(BatchSizeCardViewModel.BatchIndex)] = i + 1;
                    return node;
                }
            );
        }

        // Pick the bytes and extension to write; file extensions are matched
        // ordinally and case-insensitively
        byte[] bytesToWrite;
        string fileExtension;

        if (comfyImage.FileName.EndsWith(".png", StringComparison.OrdinalIgnoreCase))
        {
            bytesToWrite = PngDataHelper.AddMetadata(imageArray, parameters, project);
            fileExtension = "png";
        }
        else if (comfyImage.FileName.EndsWith(".webp", StringComparison.OrdinalIgnoreCase))
        {
            webpJsonOptions ??= new JsonSerializerOptions
            {
                DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
                Converters = { new JsonStringEnumConverter() }
            };

            var paramsJson = JsonSerializer.Serialize(parameters, webpJsonOptions);
            var smProject = JsonSerializer.Serialize(project, webpJsonOptions);
            var metadata = new Dictionary<ExifTag, string>
            {
                { ExifTag.ImageDescription, paramsJson },
                { ExifTag.Software, smProject }
            };

            bytesToWrite = ImageMetadata.AddMetadataToWebp(imageArray, metadata).ToArray();
            fileExtension = Path.GetExtension(comfyImage.FileName).Replace(".", "");
        }
        else
        {
            // Other formats are written without embedded metadata
            bytesToWrite = imageArray;
            fileExtension = Path.GetExtension(comfyImage.FileName).Replace(".", "");
        }

        // Write using generated name
        using var outputStream = new MemoryStream(bytesToWrite);
        var filePath = await WriteOutputImageAsync(
            outputStream,
            args,
            i + 1,
            images.Count,
            fileExtension: fileExtension
        );

        outputImages.Add(new ImageSource(filePath) { Label = imageLabel });
        EventManager.Instance.OnImageFileAdded(filePath);
    }

    // Download all images to make grid, if multiple
    if (outputImages.Count > 1)
    {
        var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray();

        try
        {
            var project = args.Project!;

            // Lock seed
            project.TryUpdateModel<SeedCardModel>("Seed", model => model with { IsRandomizeEnabled = false });

            using var grid = ImageProcessor.CreateImageGrid(loadedImages);
            using var gridData = grid.Encode();
            var gridBytes = gridData.ToArray();
            var gridBytesWithMetadata = PngDataHelper.AddMetadata(gridBytes, args.Parameters!, args.Project!);

            // Save to disk
            using var gridStream = new MemoryStream(gridBytesWithMetadata);
            var gridPath = await WriteOutputImageAsync(gridStream, args, isGrid: true);

            // Insert to start of images
            var gridImage = new ImageSource(gridPath);
            outputImages.Insert(0, gridImage);
            EventManager.Instance.OnImageFileAdded(gridPath);
        }
        finally
        {
            // The decoded images are only needed to compose the grid
            foreach (var loadedImage in loadedImages)
            {
                loadedImage?.Dispose();
            }
        }
    }

    foreach (var img in outputImages)
    {
        // Preload
        await img.GetBitmapAsync();

        // Add images
        ImageGalleryCardViewModel.ImageSources.Add(img);
    }

    return outputImages;
}
/// <summary>
/// Implementation for Generate Image.
/// The base implementation completes immediately without doing anything.
/// </summary>
protected virtual Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken) =>
    Task.CompletedTask;
/// <summary>
/// Command for the Generate Image button.
/// Cancellation is logged and swallowed; validation errors are shown as an error notification.
/// </summary>
/// <param name="options">Optional overrides (side buttons)</param>
/// <param name="cancellationToken">Cancellation token</param>
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
    GenerateFlags options = default,
    CancellationToken cancellationToken = default
)
{
    var generateOverrides = GenerateOverrides.FromFlags(options);

    try
    {
        await GenerateImageImpl(generateOverrides, cancellationToken);
    }
    catch (ValidationException validationError)
    {
        Logger.Debug("Image Generation Validation Error: {Message}", validationError.Message);
        notificationService.Show("Validation Error", validationError.Message, NotificationType.Error);
    }
    catch (OperationCanceledException)
    {
        Logger.Debug("Image Generation Canceled");
    }
}
/// <summary>
/// Returns true if the client is connected. Otherwise shows the connection help dialog
/// and returns the connection state after the dialog closes.
/// </summary>
protected async Task<bool> CheckClientConnectedWithPrompt()
{
    if (!ClientManager.IsConnected)
    {
        var helpDialog = vmFactory.Get<InferenceConnectionHelpViewModel>().CreateDialog();
        await helpDialog.ShowAsync();

        return ClientManager.IsConnected;
    }

    return true;
}
/// <summary>
/// Checks that every extension required by the prompt is installed and satisfies its
/// version constraint. If not, shows a dialog offering to install the missing extensions
/// and returns false.
/// </summary>
/// <param name="nodeDictionary">Prompt nodes whose RequiredExtensions are checked</param>
/// <returns>
/// True if nothing is required or everything required is installed and in range;
/// false otherwise, whether or not the user chose to install
/// </returns>
private async Task<bool> CheckPromptExtensionsInstalled(NodeDictionary nodeDictionary)
{
    // Get prompt required extensions
    // Just static for now but could do manifest lookup when we support custom workflows
    var requiredExtensionSpecifiers = nodeDictionary.RequiredExtensions.ToList();

    // Skip if no extensions required
    if (requiredExtensionSpecifiers.Count == 0)
    {
        return true;
    }

    // Get installed extensions
    var localPackagePair = ClientManager.Client?.LocalServerPackage.Unwrap()!;
    var manager = localPackagePair.BasePackage.ExtensionManager.Unwrap();

    var localExtensions = (
        await ((GitPackageExtensionManager)manager).GetInstalledExtensionsLiteAsync(
            localPackagePair.InstalledPackage
        )
    ).ToList();

    // Group first so two installed extensions sharing a git url cannot throw on a duplicate key
    var localExtensionsByGitUrl = localExtensions
        .Where(ext => ext.GitRepositoryUrl is not null)
        .GroupBy(ext => ext.GitRepositoryUrl!)
        .ToDictionary(group => group.Key, group => group.First());

    var missingExtensions = new List<ExtensionSpecifier>();
    var outOfDateExtensions =
        new List<(ExtensionSpecifier Specifier, InstalledPackageExtension Installed)>();

    // Check missing extensions and out of date extensions
    foreach (var specifier in requiredExtensionSpecifiers)
    {
        if (!localExtensionsByGitUrl.TryGetValue(specifier.Name, out var localExtension))
        {
            missingExtensions.Add(specifier);
            continue;
        }

        // Check if constraint is specified
        if (specifier.Constraint is not null && specifier.TryGetSemVersionRange(out var semVersionRange))
        {
            // Get version to compare
            localExtension = await manager.GetInstalledExtensionInfoAsync(localExtension);

            // Try to parse local tag to semver
            if (
                localExtension.Version?.Tag is not null
                && SemVersion.TryParse(
                    localExtension.Version.Tag,
                    SemVersionStyles.AllowV,
                    out var localSemVersion
                )
            )
            {
                // Check if not satisfied
                if (!semVersionRange.Contains(localSemVersion))
                {
                    outOfDateExtensions.Add((specifier, localExtension));
                }
            }
        }
    }

    if (missingExtensions.Count == 0 && outOfDateExtensions.Count == 0)
    {
        return true;
    }

    // One markdown bullet per extension, so the first item is a list entry too and the
    // out-of-date entries do not run into the last missing one
    var requiredExtensionLines = missingExtensions
        .Select(ext => $"- {ext.Name}")
        .Concat(
            outOfDateExtensions.Select(
                pair =>
                    $"- {pair.Specifier.Name} {pair.Specifier.Constraint} {pair.Specifier.Version} (Current Version: {pair.Installed.Version?.Tag})"
            )
        );

    var dialog = DialogHelper.CreateMarkdownDialog(
        "#### The following extensions are required for this workflow:\n"
            + string.Join("\n", requiredExtensionLines),
        "Install Required Extensions?"
    );

    dialog.IsPrimaryButtonEnabled = true;
    dialog.DefaultButton = ContentDialogButton.Primary;
    dialog.PrimaryButtonText =
        $"{Resources.Action_Install} ({localPackagePair.InstalledPackage.DisplayName.ToRepr()} will restart)";
    dialog.CloseButtonText = Resources.Action_Cancel;

    if (await dialog.ShowAsync() == ContentDialogResult.Primary)
    {
        var manifestExtensionsMap = await manager.GetManifestExtensionsMapAsync(
            manager.GetManifests(localPackagePair.InstalledPackage)
        );

        // NOTE(review): only missing extensions get install steps; out-of-date extensions
        // are listed in the dialog but not updated here - confirm whether that is intended
        var steps = new List<IPackageStep>();

        foreach (var missingExtension in missingExtensions)
        {
            if (!manifestExtensionsMap.TryGetValue(missingExtension.Name, out var extension))
            {
                Logger.Warn(
                    "Extension {MissingExtensionUrl} not found in manifests",
                    missingExtension.Name
                );
                continue;
            }

            steps.Add(new InstallExtensionStep(manager, localPackagePair.InstalledPackage, extension));
        }

        var runner = new PackageModificationRunner
        {
            ShowDialogOnStart = true,
            ModificationCompleteTitle = "Extensions Installed",
            ModificationCompleteMessage = "Finished installing required extensions"
        };
        EventManager.Instance.OnPackageInstallProgressAdded(runner);

        runner
            .ExecuteSteps(steps)
            .ContinueWith(async _ =>
            {
                if (runner.Failed)
                    return;

                // Restart Package
                // TODO: This should be handled by some DI package manager service
                try
                {
                    // Resolved inside the try so a lookup failure is reported like a restart failure
                    var launchPage = App.Services.GetRequiredService<LaunchPageViewModel>();

                    await Dispatcher.UIThread.InvokeAsync(async () =>
                    {
                        await launchPage.Stop();
                        await launchPage.LaunchAsync();
                    });
                }
                catch (Exception e)
                {
                    Logger.Error(e, "Error while restarting package");

                    notificationService.ShowPersistent(
                        new AppException(
                            "Could not restart package",
                            "Please manually restart the package for extension changes to take effect"
                        )
                    );
                }
            })
            .SafeFireAndForget();
    }

    return false;
}
/// <summary>
/// Handles the preview image received event from the websocket
/// by passing the received image bytes to the image gallery as its preview image.
/// </summary>
protected virtual void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args) =>
    ImageGalleryCardViewModel.SetPreviewImage(args.ImageBytes);
/// <summary>
/// Handles the progress update received event from the websocket.
/// Posts to the UI thread an update of the progress view model: value, maximum,
/// determinate state, and a "(value / maximum)" text followed by the running node, if any.
/// </summary>
protected virtual void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
{
    Dispatcher.UIThread.Post(() =>
    {
        var progressText = $"({args.Value} / {args.Maximum})";
        if (args.RunningNode != null)
        {
            progressText += $" {args.RunningNode}";
        }

        OutputProgress.Value = args.Value;
        OutputProgress.Maximum = args.Maximum;
        OutputProgress.IsIndeterminate = false;
        OutputProgress.Text = progressText;
    });
}
/// <summary>
/// Subscribes <see cref="OnRunningNodeChanged"/> to the task's running node changes,
/// first invoking it once with the most recent node from the task's history, if any.
/// </summary>
private void AttachRunningNodeChangedHandler(ComfyTask comfyTask)
{
    // Initial update from the latest known node
    var hasCurrentNode = comfyTask.RunningNodesHistory.TryPeek(out var currentNode);
    if (hasCurrentNode)
    {
        OnRunningNodeChanged(comfyTask, currentNode);
    }

    comfyTask.RunningNodeChanged += OnRunningNodeChanged;
}
/// <summary>
/// Handles the node executing updates received event from the websocket.
/// While the sending task has not started regular progress updates, shows indeterminate
/// progress with the node name as text; otherwise does nothing.
/// </summary>
protected virtual void OnRunningNodeChanged(object? sender, string? nodeName)
{
    // Only act for a ComfyTask whose regular progress updates have not started
    if (sender is ComfyTask { HasProgressUpdateStarted: false })
    {
        Dispatcher.UIThread.Post(() =>
        {
            OutputProgress.IsIndeterminate = true;
            OutputProgress.Value = 100;
            OutputProgress.Maximum = 100;
            OutputProgress.Text = nodeName;
        });
    }
}
/// <summary>
/// Arguments describing a single generation run.
/// </summary>
public class ImageGenerationEventArgs : EventArgs
{
    /// <summary>Client used to upload inputs, queue the prompt, and download outputs.</summary>
    public required ComfyClient Client { get; init; }

    /// <summary>Prompt nodes to queue; also the source of the required extensions check.</summary>
    public required NodeDictionary Nodes { get; init; }

    /// <summary>Names of the output nodes; a generation run requires at least one.</summary>
    public required IReadOnlyList<string> OutputNodeNames { get; init; }

    /// <summary>Index of this run within a batch; required extensions are only checked for index 0.</summary>
    public int BatchIndex { get; init; }

    /// <summary>Generation parameters embedded in output image metadata; a generation run requires them to be non-null.</summary>
    public GenerationParameters? Parameters { get; init; }

    /// <summary>Project document embedded in output image metadata; a generation run requires it to be non-null.</summary>
    public InferenceProjectDocument? Project { get; init; }

    /// <summary>Whether the gallery's existing images are cleared before the new outputs are added.</summary>
    public bool ClearOutputImages { get; init; } = true;

    /// <summary>Files uploaded to the client before the prompt is queued.</summary>
    public List<(string SourcePath, string DestinationRelativePath)> FilesToTransfer { get; init; } = new();
}
/// <summary>
/// Arguments passed while building a prompt; convertible to <see cref="ModuleApplyStepEventArgs"/>.
/// </summary>
public class BuildPromptEventArgs : EventArgs
{
    /// <summary>Node builder for the prompt.</summary>
    public ComfyNodeBuilder Builder { get; } = new();

    /// <summary>Generation overrides for this build.</summary>
    public GenerateOverrides Overrides { get; init; } = new();

    /// <summary>Optional seed override.</summary>
    public long? SeedOverride { get; init; }

    /// <summary>Files to transfer, as pairs of source path and destination relative path.</summary>
    public List<(string SourcePath, string DestinationRelativePath)> FilesToTransfer { get; init; } = new();

    /// <summary>
    /// Creates module apply-step arguments sharing this instance's builder and file list.
    /// The hires fix override, when set, is carried over as an enabled override for
    /// <see cref="HiresFixModule"/>.
    /// </summary>
    public ModuleApplyStepEventArgs ToModuleApplyStepEventArgs()
    {
        var enabledOverrides = new Dictionary<Type, bool>();

        if (Overrides.IsHiresFixEnabled is { } hiresFixEnabled)
        {
            enabledOverrides[typeof(HiresFixModule)] = hiresFixEnabled;
        }

        return new ModuleApplyStepEventArgs
        {
            Builder = Builder,
            IsEnabledOverrides = enabledOverrides,
            FilesToTransfer = FilesToTransfer
        };
    }

    /// <summary>Implicit conversion that calls <see cref="ToModuleApplyStepEventArgs"/>.</summary>
    public static implicit operator ModuleApplyStepEventArgs(BuildPromptEventArgs args) =>
        args.ToModuleApplyStepEventArgs();
}
}