diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs index a6303f21..e04f4f46 100644 --- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs +++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs @@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions builder.Connections.ImageSize = builder.Connections.LatentSize; } - var saveImage = builder.Nodes.AddNamedNode( + var previewImage = builder.Nodes.AddNamedNode( new NamedComfyNode("SaveImage") { - ClassType = "SaveImage", - Inputs = new Dictionary - { - ["filename_prefix"] = "Inference/TextToImage", - ["images"] = builder.Connections.Image - } + ClassType = "PreviewImage", + Inputs = new Dictionary { ["images"] = builder.Connections.Image } } ); - builder.Connections.OutputNodes.Add(saveImage); + builder.Connections.OutputNodes.Add(previewImage); - return saveImage.Name; + return previewImage.Name; } } diff --git a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs index 28c215d6..a090c6a2 100644 --- a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs +++ b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs @@ -13,50 +13,57 @@ public static class ImageProcessor /// public static (int rows, int columns) GetGridDimensionsFromImageCount(int count) { - if (count <= 1) return (1, 1); - if (count == 2) return (1, 2); - + if (count <= 1) + return (1, 1); + if (count == 2) + return (1, 2); + // Prefer one extra row over one extra column, // the row count will be the floor of the square root // and the column count will be floor of count / rows - var rows = (int) Math.Floor(Math.Sqrt(count)); - var columns = (int) Math.Floor((double) count / rows); + var rows = (int)Math.Floor(Math.Sqrt(count)); + var columns = (int)Math.Floor((double)count / rows); return (rows, columns); } - - public static SKImage CreateImageGrid( - IReadOnlyList images, - int spacing = 0) + + public static SKImage CreateImageGrid(IReadOnlyList images, int spacing = 0) { + if (images.Count == 0) + throw new ArgumentException("Must have at least one image"); + var (rows, columns) = GetGridDimensionsFromImageCount(images.Count); var singleWidth = images[0].Width; var singleHeight = images[0].Height; - + // Make output image using var output = new SKBitmap( - singleWidth * columns + spacing * (columns - 1), - singleHeight * rows + spacing * (rows - 1)); - + singleWidth * columns + spacing * (columns - 1), + singleHeight * rows + spacing * (rows - 1) + ); + // Draw images using var canvas = new SKCanvas(output); - - foreach (var (row, column) in - Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))) + + foreach ( + var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns)) + ) { // Stop if we have drawn all images var index = row * columns + column; - if (index >= images.Count) break; - + if (index >= images.Count) + break; + // Get image var image = images[index]; - + // Draw image var destination = new SKRect( singleWidth * column + spacing * column, singleHeight * row + spacing * row, singleWidth * column + spacing * column + image.Width, - singleHeight * row + spacing * row + image.Height); + singleHeight * row + spacing * row + image.Height + ); canvas.DrawImage(image, destination); } diff --git a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs index 3e71a877..aec47d07 100644 --- a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -345,6 +346,61 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken); } + private async Task MigrateLinksIfNeeded(PackagePair packagePair) + { + if (packagePair.InstalledPackage.FullPath is not { } packagePath) + { + throw new ArgumentException("Package path is null", nameof(packagePair)); + } + + var newDestination = settingsManager.ImagesInferenceDirectory; + + // If new destination is a reparse point (like before, delete it first) + if (newDestination is { IsSymbolicLink: true, Info.LinkTarget: null }) + { + logger.LogInformation("Deleting existing link target at {NewDir}", newDestination); + newDestination.Info.Attributes = FileAttributes.Normal; + await newDestination.DeleteAsync(false).ConfigureAwait(false); + } + + newDestination.Create(); + + // For locally installed packages only + // Move all files in ./output/Inference to /Images/Inference and delete ./output/Inference + + var legacyLinkSource = new DirectoryPath(packagePair.InstalledPackage.FullPath).JoinDir( + "output", + "Inference" + ); + if (!legacyLinkSource.Exists) + { + return; + } + + // Move files if not empty + if (legacyLinkSource.Info.EnumerateFiles().Any()) + { + logger.LogInformation( + "Moving files from {LegacyDir} to {NewDir}", + legacyLinkSource, + newDestination + ); + await FileTransfers + .MoveAllFilesAndDirectories( + legacyLinkSource, + newDestination, + overwriteIfHashMatches: true, + overwrite: false + ) + .ConfigureAwait(false); + } + + // Delete legacy link + logger.LogInformation("Deleting legacy link at {LegacyDir}", legacyLinkSource); + legacyLinkSource.Info.Attributes = FileAttributes.Normal; + await legacyLinkSource.DeleteAsync(false).ConfigureAwait(false); + } + /// public async Task ConnectAsync( PackagePair packagePair, @@ -367,11 +423,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient logger.LogError(ex, "Error setting up completion provider"); }); - // Setup image folder links - await comfyPackage.SetupInferenceOutputFolderLinks( - packagePair.InstalledPackage.FullPath - ?? throw new InvalidOperationException("Package does not have a Path") - ); + await MigrateLinksIfNeeded(packagePair); // Get user defined host and port var host = packagePair.InstalledPackage.GetLaunchArgsHost(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index 3bd7e614..c7fe35e0 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -3,11 +3,14 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; +using Avalonia.Controls.Notifications; +using Avalonia.Media.Imaging; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using NLog; @@ -27,6 +30,8 @@ using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Base; @@ -41,6 +46,7 @@ public abstract partial class InferenceGenerationViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + private readonly ISettingsManager settingsManager; private readonly INotificationService notificationService; private readonly ServiceManager vmFactory; @@ -60,11 +66,13 @@ public abstract partial class InferenceGenerationViewModelBase protected InferenceGenerationViewModelBase( ServiceManager vmFactory, IInferenceClientManager inferenceClientManager, - INotificationService notificationService + INotificationService notificationService, + ISettingsManager settingsManager ) : base(notificationService) { this.notificationService = notificationService; + this.settingsManager = settingsManager; this.vmFactory = vmFactory; ClientManager = inferenceClientManager; @@ -75,6 +83,100 @@ public abstract partial class InferenceGenerationViewModelBase GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); } + /// + /// Write an image to the default output folder + /// + protected Task WriteOutputImageAsync( + Stream imageStream, + ImageGenerationEventArgs args, + int batchNum = 0, + int batchTotal = 0, + bool isGrid = false + ) + { + var defaultOutputDir = settingsManager.ImagesInferenceDirectory; + defaultOutputDir.Create(); + + return WriteOutputImageAsync( + imageStream, + defaultOutputDir, + args, + batchNum, + batchTotal, + isGrid + ); + } + + /// + /// Write an image to an output folder + /// + protected async Task WriteOutputImageAsync( + Stream imageStream, + DirectoryPath outputDir, + ImageGenerationEventArgs args, + int batchNum = 0, + int batchTotal = 0, + bool isGrid = false + ) + { + var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat; + + var formatProvider = new FileNameFormatProvider + { + GenerationParameters = args.Parameters, + ProjectType = args.Project?.ProjectType, + ProjectName = ProjectFile?.NameWithoutExtension + }; + + // Parse to format + if ( + string.IsNullOrEmpty(formatTemplateStr) + || !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format) + ) + { + // Fallback to default + Logger.Warn( + "Failed to parse format template: {FormatTemplate}, using default", + formatTemplateStr + ); + + format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider); + } + + if (isGrid) + { + format = format.WithGridPrefix(); + } + + if (batchNum >= 1 && batchTotal > 1) + { + format = format.WithBatchPostFix(batchNum, batchTotal); + } + + var fileName = format.GetFileName() + ".png"; + var file = outputDir.JoinFile(fileName); + + // Until the file is free, keep adding _{i} to the end + for (var i = 0; i < 100; i++) + { + if (!file.Exists) + break; + + file = outputDir.JoinFile($"{fileName}_{i + 1}"); + } + + // If that fails, append an 7-char uuid + if (file.Exists) + { + file = outputDir.JoinFile($"{fileName}_{Guid.NewGuid():N}"[..7]); + } + + await using var fileStream = file.Info.OpenWrite(); + await imageStream.CopyToAsync(fileStream); + + return file; + } + /// /// Builds the image generation prompt /// @@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase // Wait for prompt to finish await promptTask.Task.WaitAsync(cancellationToken); - Logger.Trace($"Prompt task {promptTask.Id} finished"); + Logger.Debug($"Prompt task {promptTask.Id} finished"); // Get output images var imageOutputs = await client.GetImagesForExecutedPromptAsync( @@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase cancellationToken ); + if ( + !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) + || images is not { Count: > 0 } + ) + { + // No images match + notificationService.Show( + "No output", + "Did not receive any output images", + NotificationType.Warning + ); + return; + } + // Disable cancellation await promptInterrupt.DisposeAsync(); @@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase ImageGalleryCardViewModel.ImageSources.Clear(); } - if ( - !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null - ) - { - // No images match - notificationService.Show("No output", "Did not receive any output images"); - return; - } - await ProcessOutputImages(images, args); } finally @@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase ImageGenerationEventArgs args ) { + var client = args.Client; + // Write metadata to images + var outputImagesBytes = new List(); var outputImages = new List(); - foreach ( - var (i, filePath) in images - .Select(image => image.ToFilePath(args.Client.OutputImagesDir!)) - .Enumerate() - ) + + foreach (var (i, comfyImage) in images.Enumerate()) { - if (!filePath.Exists) - { - Logger.Warn($"Image file {filePath} does not exist"); - continue; - } + Logger.Debug("Downloading image: {FileName}", comfyImage.FileName); + var imageStream = await client.GetImageStreamAsync(comfyImage); + + using var ms = new MemoryStream(); + await imageStream.CopyToAsync(ms); + + var imageArray = ms.ToArray(); + outputImagesBytes.Add(imageArray); var parameters = args.Parameters!; var project = args.Project!; @@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase ); } - var bytesWithMetadata = PngDataHelper.AddMetadata( - await filePath.ReadAllBytesAsync(), - parameters, - project - ); + var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project); - await using (var outputStream = filePath.Info.OpenWrite()) - { - await outputStream.WriteAsync(bytesWithMetadata); - await outputStream.FlushAsync(); - } + // Write using generated name + var filePath = await WriteOutputImageAsync( + new MemoryStream(bytesWithMetadata), + args, + i + 1, + images.Count + ); outputImages.Add(new ImageSource(filePath)); @@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase // Download all images to make grid, if multiple if (outputImages.Count > 1) { - var outputDir = outputImages[0].LocalFile!.Directory; - - var loadedImages = outputImages - .Select(i => i.LocalFile) - .Where(f => f is { Exists: true }) - .Select(f => - { - using var stream = f!.Info.OpenRead(); - return SKImage.FromEncodedData(stream); - }) - .ToImmutableArray(); + var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray(); var project = args.Project!; @@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase ); // Save to disk - var lastName = outputImages.Last().LocalFile?.Info.Name; - var gridPath = outputDir!.JoinFile($"grid-{lastName}"); - - await using (var fileStream = gridPath.Info.OpenWrite()) - { - await fileStream.WriteAsync(gridBytesWithMetadata); - } + var gridPath = await WriteOutputImageAsync( + new MemoryStream(gridBytesWithMetadata), + args, + isGrid: true + ); // Insert to start of images var gridImage = new ImageSource(gridPath); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs index 9e56a16d..d97cf780 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs @@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Views.Inference; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; +using StabilityMatrix.Core.Services; using Path = System.IO.Path; #pragma warning disable CS0657 // Not a valid attribute location for this declaration @@ -60,9 +61,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase public InferenceImageUpscaleViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, + ISettingsManager settingsManager, ServiceManager vmFactory ) - : base(vmFactory, inferenceClientManager, notificationService) + : base(vmFactory, inferenceClientManager, notificationService, settingsManager) { this.notificationService = notificationService; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 07124ac0..929aa771 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -86,10 +86,11 @@ public class InferenceTextToImageViewModel public InferenceTextToImageViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, + ISettingsManager settingsManager, ServiceManager vmFactory, IModelIndexService modelIndexService ) - : base(vmFactory, inferenceClientManager, notificationService) + : base(vmFactory, inferenceClientManager, notificationService, settingsManager) { this.notificationService = notificationService; this.modelIndexService = modelIndexService;