From a40e4e5a989399ba43ab106b10814a7d7b9361fa Mon Sep 17 00:00:00 2001 From: Ionite Date: Wed, 11 Oct 2023 23:57:22 -0400 Subject: [PATCH 01/13] Add file name formatting frameworks --- .../Models/Inference/FileNameFormat.cs | 68 +++++++++++ .../Models/Inference/FileNameFormatPart.cs | 5 + .../Inference/FileNameFormatProvider.cs | 111 ++++++++++++++++++ .../Avalonia/FileNameFormatProviderTests.cs | 25 ++++ .../Avalonia/FileNameFormatTests.cs | 24 ++++ 5 files changed, 233 insertions(+) create mode 100644 StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs create mode 100644 StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs create mode 100644 StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs create mode 100644 StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs create mode 100644 StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs new file mode 100644 index 00000000..28e9a35a --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormat +{ + public string Template { get; } + + public string Prefix { get; set; } = ""; + + public string Postfix { get; set; } = ""; + + public IReadOnlyList Parts { get; } + + private FileNameFormat(string template, IReadOnlyList parts) + { + Template = template; + Parts = parts; + } + + public FileNameFormat WithBatchPostFix(int current, int total) + { + return this with { Postfix = Postfix + $" ({current}-{total})" }; + } + + public FileNameFormat WithGridPrefix() + { + return this with { Prefix = Prefix + "Grid_" }; + } + + public string GetFileName() + { + return Prefix + + string.Join("", Parts.Select(p => p.Constant ?? p.Substitution?.Invoke() ?? "")) + + Postfix; + } + + public static FileNameFormat Parse(string template, FileNameFormatProvider provider) + { + provider.Validate(template); + var parts = provider.GetParts(template).ToImmutableArray(); + return new FileNameFormat(template, parts); + } + + public static bool TryParse( + string template, + FileNameFormatProvider provider, + [NotNullWhen(true)] out FileNameFormat? format + ) + { + try + { + format = Parse(template, provider); + return true; + } + catch (ArgumentException) + { + format = null; + return false; + } + } + + public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}"; +} diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs new file mode 100644 index 00000000..bfbcc8d9 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs @@ -0,0 +1,5 @@ +using System; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormatPart(string? Constant, Func? Substitution); diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs new file mode 100644 index 00000000..e6ecc563 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -0,0 +1,111 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public partial class FileNameFormatProvider +{ + public GenerationParameters? GenerationParameters { get; init; } + + public InferenceProjectType? ProjectType { get; init; } + + public string? ProjectName { get; init; } + + private Dictionary>? _substitutions; + + private Dictionary> Substitutions => + _substitutions ??= new Dictionary> + { + { "seed", () => GenerationParameters?.Seed.ToString() }, + { "model_name", () => GenerationParameters?.ModelName }, + { "model_hash", () => GenerationParameters?.ModelHash }, + { "width", () => GenerationParameters?.Width.ToString() }, + { "height", () => GenerationParameters?.Height.ToString() }, + { "project_type", () => ProjectType?.GetStringValue() }, + { "project_name", () => ProjectName }, + { "date", () => DateTime.Now.ToString("yyyy-MM-dd") }, + { "time", () => DateTime.Now.ToString("HH-mm-ss") } + }; + + public (int Current, int Total)? BatchInfo { get; init; } + + /// + /// Validate a format string + /// + public void Validate(string format) + { + var regex = BracketRegex(); + var matches = regex.Matches(format); + var variables = matches.Select(m => m.Value[1..^1]).ToList(); + + foreach (var variable in variables) + { + if (!Substitutions.ContainsKey(variable)) + { + throw new ArgumentException($"Unknown variable '{variable}'"); + } + } + } + + public IEnumerable GetParts(string template) + { + var regex = BracketRegex(); + var matches = regex.Matches(template); + + var parts = new List(); + + // Loop through all parts of the string, including matches and non-matches + var currentIndex = 0; + + foreach (var result in matches.Cast()) + { + // If the match is not at the start of the string, add a constant part + if (result.Index != currentIndex) + { + var constant = template[currentIndex..result.Index]; + parts.Add(new FileNameFormatPart(constant, null)); + + currentIndex += constant.Length; + } + + var variable = result.Value[1..^1]; + parts.Add(new FileNameFormatPart(null, Substitutions[variable])); + + currentIndex += result.Length; + } + + // Add remaining as constant + if (currentIndex != template.Length) + { + var constant = template[currentIndex..]; + parts.Add(new FileNameFormatPart(constant, null)); + } + + return parts; + } + + /// + /// Return a string substituting the variables in the format string + /// + private string? GetSubstitution(string variable) + { + return variable switch + { + "seed" => GenerationParameters.Seed.ToString(), + "model_name" => GenerationParameters.ModelName, + "model_hash" => GenerationParameters.ModelHash, + "width" => GenerationParameters.Width.ToString(), + "height" => GenerationParameters.Height.ToString(), + "date" => DateTime.Now.ToString("yyyy-MM-dd"), + "time" => DateTime.Now.ToString("HH-mm-ss"), + _ => throw new ArgumentOutOfRangeException(nameof(variable), variable, null) + }; + } + + [GeneratedRegex(@"\{[a-z_]+\}")] + private static partial Regex BracketRegex(); +} diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs new file mode 100644 index 00000000..cdf7fdfa --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs @@ -0,0 +1,25 @@ +using StabilityMatrix.Avalonia.Models.Inference; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class FileNameFormatProviderTests +{ + [TestMethod] + public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow() + { + var provider = new FileNameFormatProvider(); + + provider.Validate("{date}_{time}-{model_name}-{seed}"); + } + + [TestMethod] + public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow() + { + var provider = new FileNameFormatProvider(); + + Assert.ThrowsException( + () => provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}") + ); + } +} diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs new file mode 100644 index 00000000..0da1eb84 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs @@ -0,0 +1,24 @@ +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class FileNameFormatTests +{ + [TestMethod] + public void TestFileNameFormatParse() + { + var provider = new FileNameFormatProvider + { + GenerationParameters = new GenerationParameters { Seed = 123 }, + ProjectName = "uwu", + ProjectType = InferenceProjectType.TextToImage, + }; + + var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider); + + Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName()); + } +} From 10141c00a12e08fdb9d536619471186fc8ece58c Mon Sep 17 00:00:00 2001 From: Ionite Date: Wed, 11 Oct 2023 23:57:54 -0400 Subject: [PATCH 02/13] Add InferenceOutputImageFileNameFormat Setting --- .../ViewModels/SettingsViewModel.cs | 10 +++++++ .../Views/SettingsPage.axaml | 26 +++++++++++++++++-- .../Models/Settings/Settings.cs | 5 ++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs index eee148fa..97772909 100644 --- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs @@ -107,6 +107,9 @@ public partial class SettingsViewModel : PageViewModelBase [ObservableProperty] private bool isCompletionRemoveUnderscoresEnabled = true; + [ObservableProperty] + private string? outputImageFileNameFormat; + [ObservableProperty] private bool isImageViewerPixelGridEnabled = true; @@ -201,6 +204,13 @@ public partial class SettingsViewModel : PageViewModelBase true ); + settingsManager.RelayPropertyFor( + this, + vm => vm.OutputImageFileNameFormat, + settings => settings.InferenceOutputImageFileNameFormat, + true + ); + settingsManager.RelayPropertyFor( this, vm => vm.IsImageViewerPixelGridEnabled, diff --git a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml index 21ab36ac..d51bce28 100644 --- a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml +++ b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml @@ -6,10 +6,12 @@ xmlns:controls="clr-namespace:StabilityMatrix.Avalonia.Controls" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" + xmlns:fluentIcons="clr-namespace:FluentIcons.FluentAvalonia;assembly=FluentIcons.FluentAvalonia" xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData" xmlns:ui="using:FluentAvalonia.UI.Controls" xmlns:vm="clr-namespace:StabilityMatrix.Avalonia.ViewModels" xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages" + xmlns:avaloniaEdit="https://github.com/avaloniaui/avaloniaedit" d:DataContext="{x:Static mocks:DesignData.SettingsViewModel}" d:DesignHeight="700" d:DesignWidth="800" @@ -83,10 +85,10 @@ - + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Core/Models/Settings/Settings.cs b/StabilityMatrix.Core/Models/Settings/Settings.cs index cb01f14e..73d02cb6 100644 --- a/StabilityMatrix.Core/Models/Settings/Settings.cs +++ b/StabilityMatrix.Core/Models/Settings/Settings.cs @@ -70,6 +70,11 @@ public class Settings /// public bool IsCompletionRemoveUnderscoresEnabled { get; set; } = true; + /// + /// Format for Inference output image file names + /// + public string? InferenceOutputImageFileNameFormat { get; set; } + /// /// Whether the Inference Image Viewer shows pixel grids at high zoom levels /// From 9cacf9283bd32d87c76e6a7d2fd9840fc42bdbb3 Mon Sep 17 00:00:00 2001 From: Ionite Date: Wed, 11 Oct 2023 23:58:12 -0400 Subject: [PATCH 03/13] Add Stream overload for Png AddMetadata --- StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs b/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs index 086d0785..aa31add4 100644 --- a/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs @@ -16,6 +16,17 @@ public static class PngDataHelper private static readonly byte[] Text = { 0x74, 0x45, 0x58, 0x74 }; private static readonly byte[] Iend = { 0x49, 0x45, 0x4E, 0x44 }; + public static byte[] AddMetadata( + Stream inputStream, + GenerationParameters generationParameters, + InferenceProjectDocument projectDocument + ) + { + using var ms = new MemoryStream(); + inputStream.CopyTo(ms); + return AddMetadata(ms.ToArray(), generationParameters, projectDocument); + } + public static byte[] AddMetadata( byte[] inputImage, GenerationParameters generationParameters, From 2f033893140d1a05c3d087c3498379417470b751 Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 00:13:41 -0400 Subject: [PATCH 04/13] Refresh Info before checking access time --- StabilityMatrix.Core/Models/Database/LocalImageFile.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/StabilityMatrix.Core/Models/Database/LocalImageFile.cs b/StabilityMatrix.Core/Models/Database/LocalImageFile.cs index 184b9c63..bf728a1b 100644 --- a/StabilityMatrix.Core/Models/Database/LocalImageFile.cs +++ b/StabilityMatrix.Core/Models/Database/LocalImageFile.cs @@ -126,6 +126,8 @@ public class LocalImageFile GenerationParameters.TryParse(metadata, out genParams); } + filePath.Info.Refresh(); + return new LocalImageFile { RelativePath = relativePath, From a2c3acb95256651a41f29d242f35d6ef2b16e07b Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 01:02:29 -0400 Subject: [PATCH 05/13] Change Inference to use downloaded images and custom file name formatting --- .../Extensions/ComfyNodeBuilderExtensions.cs | 14 +- .../Helpers/ImageProcessor.cs | 47 +++-- .../Services/InferenceClientManager.cs | 62 +++++- .../Base/InferenceGenerationViewModelBase.cs | 194 +++++++++++++----- .../InferenceImageUpscaleViewModel.cs | 4 +- .../InferenceTextToImageViewModel.cs | 3 +- 6 files changed, 239 insertions(+), 85 deletions(-) diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs index a6303f21..e04f4f46 100644 --- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs +++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs @@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions builder.Connections.ImageSize = builder.Connections.LatentSize; } - var saveImage = builder.Nodes.AddNamedNode( + var previewImage = builder.Nodes.AddNamedNode( new NamedComfyNode("SaveImage") { - ClassType = "SaveImage", - Inputs = new Dictionary - { - ["filename_prefix"] = "Inference/TextToImage", - ["images"] = builder.Connections.Image - } + ClassType = "PreviewImage", + Inputs = new Dictionary { ["images"] = builder.Connections.Image } } ); - builder.Connections.OutputNodes.Add(saveImage); + builder.Connections.OutputNodes.Add(previewImage); - return saveImage.Name; + return previewImage.Name; } } diff --git a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs index 28c215d6..a090c6a2 100644 --- a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs +++ b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs @@ -13,50 +13,57 @@ public static class ImageProcessor /// public static (int rows, int columns) GetGridDimensionsFromImageCount(int count) { - if (count <= 1) return (1, 1); - if (count == 2) return (1, 2); - + if (count <= 1) + return (1, 1); + if (count == 2) + return (1, 2); + // Prefer one extra row over one extra column, // the row count will be the floor of the square root // and the column count will be floor of count / rows - var rows = (int) Math.Floor(Math.Sqrt(count)); - var columns = (int) Math.Floor((double) count / rows); + var rows = (int)Math.Floor(Math.Sqrt(count)); + var columns = (int)Math.Floor((double)count / rows); return (rows, columns); } - - public static SKImage CreateImageGrid( - IReadOnlyList images, - int spacing = 0) + + public static SKImage CreateImageGrid(IReadOnlyList images, int spacing = 0) { + if (images.Count == 0) + throw new ArgumentException("Must have at least one image"); + var (rows, columns) = GetGridDimensionsFromImageCount(images.Count); var singleWidth = images[0].Width; var singleHeight = images[0].Height; - + // Make output image using var output = new SKBitmap( - singleWidth * columns + spacing * (columns - 1), - singleHeight * rows + spacing * (rows - 1)); - + singleWidth * columns + spacing * (columns - 1), + singleHeight * rows + spacing * (rows - 1) + ); + // Draw images using var canvas = new SKCanvas(output); - - foreach (var (row, column) in - Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))) + + foreach ( + var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns)) + ) { // Stop if we have drawn all images var index = row * columns + column; - if (index >= images.Count) break; - + if (index >= images.Count) + break; + // Get image var image = images[index]; - + // Draw image var destination = new SKRect( singleWidth * column + spacing * column, singleHeight * row + spacing * row, singleWidth * column + spacing * column + image.Width, - singleHeight * row + spacing * row + image.Height); + singleHeight * row + spacing * row + image.Height + ); canvas.DrawImage(image, destination); } diff --git a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs index 3e71a877..aec47d07 100644 --- a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -345,6 +346,61 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken); } + private async Task MigrateLinksIfNeeded(PackagePair packagePair) + { + if (packagePair.InstalledPackage.FullPath is not { } packagePath) + { + throw new ArgumentException("Package path is null", nameof(packagePair)); + } + + var newDestination = settingsManager.ImagesInferenceDirectory; + + // If new destination is a reparse point (like before, delete it first) + if (newDestination is { IsSymbolicLink: true, Info.LinkTarget: null }) + { + logger.LogInformation("Deleting existing link target at {NewDir}", newDestination); + newDestination.Info.Attributes = FileAttributes.Normal; + await newDestination.DeleteAsync(false).ConfigureAwait(false); + } + + newDestination.Create(); + + // For locally installed packages only + // Move all files in ./output/Inference to /Images/Inference and delete ./output/Inference + + var legacyLinkSource = new DirectoryPath(packagePair.InstalledPackage.FullPath).JoinDir( + "output", + "Inference" + ); + if (!legacyLinkSource.Exists) + { + return; + } + + // Move files if not empty + if (legacyLinkSource.Info.EnumerateFiles().Any()) + { + logger.LogInformation( + "Moving files from {LegacyDir} to {NewDir}", + legacyLinkSource, + newDestination + ); + await FileTransfers + .MoveAllFilesAndDirectories( + legacyLinkSource, + newDestination, + overwriteIfHashMatches: true, + overwrite: false + ) + .ConfigureAwait(false); + } + + // Delete legacy link + logger.LogInformation("Deleting legacy link at {LegacyDir}", legacyLinkSource); + legacyLinkSource.Info.Attributes = FileAttributes.Normal; + await legacyLinkSource.DeleteAsync(false).ConfigureAwait(false); + } + /// public async Task ConnectAsync( PackagePair packagePair, @@ -367,11 +423,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient logger.LogError(ex, "Error setting up completion provider"); }); - // Setup image folder links - await comfyPackage.SetupInferenceOutputFolderLinks( - packagePair.InstalledPackage.FullPath - ?? throw new InvalidOperationException("Package does not have a Path") - ); + await MigrateLinksIfNeeded(packagePair); // Get user defined host and port var host = packagePair.InstalledPackage.GetLaunchArgsHost(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index 3bd7e614..c7fe35e0 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -3,11 +3,14 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; +using Avalonia.Controls.Notifications; +using Avalonia.Media.Imaging; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using NLog; @@ -27,6 +30,8 @@ using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Base; @@ -41,6 +46,7 @@ public abstract partial class InferenceGenerationViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + private readonly ISettingsManager settingsManager; private readonly INotificationService notificationService; private readonly ServiceManager vmFactory; @@ -60,11 +66,13 @@ public abstract partial class InferenceGenerationViewModelBase protected InferenceGenerationViewModelBase( ServiceManager vmFactory, IInferenceClientManager inferenceClientManager, - INotificationService notificationService + INotificationService notificationService, + ISettingsManager settingsManager ) : base(notificationService) { this.notificationService = notificationService; + this.settingsManager = settingsManager; this.vmFactory = vmFactory; ClientManager = inferenceClientManager; @@ -75,6 +83,100 @@ public abstract partial class InferenceGenerationViewModelBase GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); } + /// + /// Write an image to the default output folder + /// + protected Task WriteOutputImageAsync( + Stream imageStream, + ImageGenerationEventArgs args, + int batchNum = 0, + int batchTotal = 0, + bool isGrid = false + ) + { + var defaultOutputDir = settingsManager.ImagesInferenceDirectory; + defaultOutputDir.Create(); + + return WriteOutputImageAsync( + imageStream, + defaultOutputDir, + args, + batchNum, + batchTotal, + isGrid + ); + } + + /// + /// Write an image to an output folder + /// + protected async Task WriteOutputImageAsync( + Stream imageStream, + DirectoryPath outputDir, + ImageGenerationEventArgs args, + int batchNum = 0, + int batchTotal = 0, + bool isGrid = false + ) + { + var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat; + + var formatProvider = new FileNameFormatProvider + { + GenerationParameters = args.Parameters, + ProjectType = args.Project?.ProjectType, + ProjectName = ProjectFile?.NameWithoutExtension + }; + + // Parse to format + if ( + string.IsNullOrEmpty(formatTemplateStr) + || !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format) + ) + { + // Fallback to default + Logger.Warn( + "Failed to parse format template: {FormatTemplate}, using default", + formatTemplateStr + ); + + format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider); + } + + if (isGrid) + { + format = format.WithGridPrefix(); + } + + if (batchNum >= 1 && batchTotal > 1) + { + format = format.WithBatchPostFix(batchNum, batchTotal); + } + + var fileName = format.GetFileName() + ".png"; + var file = outputDir.JoinFile(fileName); + + // Until the file is free, keep adding _{i} to the end + for (var i = 0; i < 100; i++) + { + if (!file.Exists) + break; + + file = outputDir.JoinFile($"{fileName}_{i + 1}"); + } + + // If that fails, append an 7-char uuid + if (file.Exists) + { + file = outputDir.JoinFile($"{fileName}_{Guid.NewGuid():N}"[..7]); + } + + await using var fileStream = file.Info.OpenWrite(); + await imageStream.CopyToAsync(fileStream); + + return file; + } + /// /// Builds the image generation prompt /// @@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase // Wait for prompt to finish await promptTask.Task.WaitAsync(cancellationToken); - Logger.Trace($"Prompt task {promptTask.Id} finished"); + Logger.Debug($"Prompt task {promptTask.Id} finished"); // Get output images var imageOutputs = await client.GetImagesForExecutedPromptAsync( @@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase cancellationToken ); + if ( + !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) + || images is not { Count: > 0 } + ) + { + // No images match + notificationService.Show( + "No output", + "Did not receive any output images", + NotificationType.Warning + ); + return; + } + // Disable cancellation await promptInterrupt.DisposeAsync(); @@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase ImageGalleryCardViewModel.ImageSources.Clear(); } - if ( - !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null - ) - { - // No images match - notificationService.Show("No output", "Did not receive any output images"); - return; - } - await ProcessOutputImages(images, args); } finally @@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase ImageGenerationEventArgs args ) { + var client = args.Client; + // Write metadata to images + var outputImagesBytes = new List(); var outputImages = new List(); - foreach ( - var (i, filePath) in images - .Select(image => image.ToFilePath(args.Client.OutputImagesDir!)) - .Enumerate() - ) + + foreach (var (i, comfyImage) in images.Enumerate()) { - if (!filePath.Exists) - { - Logger.Warn($"Image file {filePath} does not exist"); - continue; - } + Logger.Debug("Downloading image: {FileName}", comfyImage.FileName); + var imageStream = await client.GetImageStreamAsync(comfyImage); + + using var ms = new MemoryStream(); + await imageStream.CopyToAsync(ms); + + var imageArray = ms.ToArray(); + outputImagesBytes.Add(imageArray); var parameters = args.Parameters!; var project = args.Project!; @@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase ); } - var bytesWithMetadata = PngDataHelper.AddMetadata( - await filePath.ReadAllBytesAsync(), - parameters, - project - ); + var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project); - await using (var outputStream = filePath.Info.OpenWrite()) - { - await outputStream.WriteAsync(bytesWithMetadata); - await outputStream.FlushAsync(); - } + // Write using generated name + var filePath = await WriteOutputImageAsync( + new MemoryStream(bytesWithMetadata), + args, + i + 1, + images.Count + ); outputImages.Add(new ImageSource(filePath)); @@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase // Download all images to make grid, if multiple if (outputImages.Count > 1) { - var outputDir = outputImages[0].LocalFile!.Directory; - - var loadedImages = outputImages - .Select(i => i.LocalFile) - .Where(f => f is { Exists: true }) - .Select(f => - { - using var stream = f!.Info.OpenRead(); - return SKImage.FromEncodedData(stream); - }) - .ToImmutableArray(); + var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray(); var project = args.Project!; @@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase ); // Save to disk - var lastName = outputImages.Last().LocalFile?.Info.Name; - var gridPath = outputDir!.JoinFile($"grid-{lastName}"); - - await using (var fileStream = gridPath.Info.OpenWrite()) - { - await fileStream.WriteAsync(gridBytesWithMetadata); - } + var gridPath = await WriteOutputImageAsync( + new MemoryStream(gridBytesWithMetadata), + args, + isGrid: true + ); // Insert to start of images var gridImage = new ImageSource(gridPath); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs index 9e56a16d..d97cf780 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs @@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Views.Inference; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; +using StabilityMatrix.Core.Services; using Path = System.IO.Path; #pragma warning disable CS0657 // Not a valid attribute location for this declaration @@ -60,9 +61,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase public InferenceImageUpscaleViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, + ISettingsManager settingsManager, ServiceManager vmFactory ) - : base(vmFactory, inferenceClientManager, notificationService) + : base(vmFactory, inferenceClientManager, notificationService, settingsManager) { this.notificationService = notificationService; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 07124ac0..929aa771 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -86,10 +86,11 @@ public class InferenceTextToImageViewModel public InferenceTextToImageViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, + ISettingsManager settingsManager, ServiceManager vmFactory, IModelIndexService modelIndexService ) - : base(vmFactory, inferenceClientManager, notificationService) + : base(vmFactory, inferenceClientManager, notificationService, settingsManager) { this.notificationService = notificationService; this.modelIndexService = modelIndexService; From 62c083bba3b33393ad9c93f93d5c98ddf6f893ae Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 15:23:08 -0400 Subject: [PATCH 06/13] Cleanup unused --- .../Inference/FileNameFormatProvider.cs | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs index e6ecc563..7b4f3508 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -31,8 +31,6 @@ public partial class FileNameFormatProvider { "time", () => DateTime.Now.ToString("HH-mm-ss") } }; - public (int Current, int Total)? BatchInfo { get; init; } - /// /// Validate a format string /// @@ -89,23 +87,8 @@ public partial class FileNameFormatProvider } /// - /// Return a string substituting the variables in the format string + /// Regex for matching contents within a curly brace. /// - private string? GetSubstitution(string variable) - { - return variable switch - { - "seed" => GenerationParameters.Seed.ToString(), - "model_name" => GenerationParameters.ModelName, - "model_hash" => GenerationParameters.ModelHash, - "width" => GenerationParameters.Width.ToString(), - "height" => GenerationParameters.Height.ToString(), - "date" => DateTime.Now.ToString("yyyy-MM-dd"), - "time" => DateTime.Now.ToString("HH-mm-ss"), - _ => throw new ArgumentOutOfRangeException(nameof(variable), variable, null) - }; - } - [GeneratedRegex(@"\{[a-z_]+\}")] private static partial Regex BracketRegex(); } From c422ae7b4bed588397d0e961707c96de989c4255 Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 15:23:32 -0400 Subject: [PATCH 07/13] Fix inference link migration --- .../Services/InferenceClientManager.cs | 59 +++++++------------ 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs index aec47d07..17023dfa 100644 --- a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs @@ -353,52 +353,35 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient throw new ArgumentException("Package path is null", nameof(packagePair)); } - var newDestination = settingsManager.ImagesInferenceDirectory; - - // If new destination is a reparse point (like before, delete it first) - if (newDestination is { IsSymbolicLink: true, Info.LinkTarget: null }) - { - logger.LogInformation("Deleting existing link target at {NewDir}", newDestination); - newDestination.Info.Attributes = FileAttributes.Normal; - await newDestination.DeleteAsync(false).ConfigureAwait(false); - } - - newDestination.Create(); + var inferenceDir = settingsManager.ImagesInferenceDirectory; + inferenceDir.Create(); // For locally installed packages only - // Move all files in ./output/Inference to /Images/Inference and delete ./output/Inference + // Delete ./output/Inference - var legacyLinkSource = new DirectoryPath(packagePair.InstalledPackage.FullPath).JoinDir( - "output", - "Inference" - ); - if (!legacyLinkSource.Exists) - { - return; - } + var legacyInferenceLinkDir = new DirectoryPath( + packagePair.InstalledPackage.FullPath + ).JoinDir("output", "Inference"); - // Move files if not empty - if (legacyLinkSource.Info.EnumerateFiles().Any()) + if (legacyInferenceLinkDir.Exists) { logger.LogInformation( - "Moving files from {LegacyDir} to {NewDir}", - legacyLinkSource, - newDestination + "Deleting legacy inference link at {LegacyDir}", + legacyInferenceLinkDir ); - await FileTransfers - .MoveAllFilesAndDirectories( - legacyLinkSource, - newDestination, - overwriteIfHashMatches: true, - overwrite: false - ) - .ConfigureAwait(false); - } - // Delete legacy link - logger.LogInformation("Deleting legacy link at {LegacyDir}", legacyLinkSource); - legacyLinkSource.Info.Attributes = FileAttributes.Normal; - await legacyLinkSource.DeleteAsync(false).ConfigureAwait(false); + if (legacyInferenceLinkDir.IsSymbolicLink) + { + await legacyInferenceLinkDir.DeleteAsync(false); + } + else + { + logger.LogWarning( + "Legacy inference link at {LegacyDir} is not a symbolic link, skipping", + legacyInferenceLinkDir + ); + } + } } /// From a13fcbc90397603b66d645e5f7ac209d42cb1dde Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 18:13:36 -0400 Subject: [PATCH 08/13] Add sample parameters and validation --- .../Models/Inference/FileNameFormat.cs | 9 +++- .../Models/Inference/FileNameFormatPart.cs | 13 ++++- .../Inference/FileNameFormatProvider.cs | 48 +++++++++++++---- .../Models/Inference/FileNameFormatVar.cs | 8 +++ .../StabilityMatrix.Avalonia.csproj | 1 + .../Base/InferenceGenerationViewModelBase.cs | 1 - .../ViewModels/SettingsViewModel.cs | 54 +++++++++++++++++++ .../Views/SettingsPage.axaml | 26 +++++++-- .../Models/GenerationParameters.cs | 20 +++++++ 9 files changed, 164 insertions(+), 16 deletions(-) create mode 100644 StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs index 28e9a35a..b20f6c48 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.ComponentModel.DataAnnotations; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -35,13 +36,17 @@ public record FileNameFormat public string GetFileName() { return Prefix - + string.Join("", Parts.Select(p => p.Constant ?? p.Substitution?.Invoke() ?? "")) + + string.Join( + "", + Parts.Select( + part => part.Match(constant => constant, substitution => substitution.Invoke()) + ) + ) + Postfix; } public static FileNameFormat Parse(string template, FileNameFormatProvider provider) { - provider.Validate(template); var parts = provider.GetParts(template).ToImmutableArray(); return new FileNameFormat(template, parts); } diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs index bfbcc8d9..9210adc0 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs @@ -1,5 +1,16 @@ using System; +using System.Runtime.InteropServices; +using CSharpDiscriminatedUnion.Attributes; namespace StabilityMatrix.Avalonia.Models.Inference; -public record FileNameFormatPart(string? Constant, Func? Substitution); +[GenerateDiscriminatedUnion(CaseFactoryPrefix = "From")] +[StructLayout(LayoutKind.Auto)] +public readonly partial struct FileNameFormatPart +{ + [StructCase("Constant", isDefaultValue: true)] + private readonly string constant; + + [StructCase("Substitution")] + private readonly Func substitution; +} diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs index 7b4f3508..96cacf04 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -1,7 +1,11 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.Contracts; using System.Linq; using System.Text.RegularExpressions; +using Avalonia.Data; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; @@ -17,7 +21,7 @@ public partial class FileNameFormatProvider private Dictionary>? _substitutions; - private Dictionary> Substitutions => + public Dictionary> Substitutions => _substitutions ??= new Dictionary> { { "seed", () => GenerationParameters?.Seed.ToString() }, @@ -34,19 +38,24 @@ public partial class FileNameFormatProvider /// /// Validate a format string /// - public void Validate(string format) + /// Format string + /// Thrown if the format string contains an unknown variable + [Pure] + public ValidationResult Validate(string format) { var regex = BracketRegex(); var matches = regex.Matches(format); - var variables = matches.Select(m => m.Value[1..^1]).ToList(); + var variables = matches.Select(m => m.Groups[1].Value); foreach (var variable in variables) { if (!Substitutions.ContainsKey(variable)) { - throw new ArgumentException($"Unknown variable '{variable}'"); + return new ValidationResult($"Unknown variable '{variable}'"); } } + + return ValidationResult.Success!; } public IEnumerable GetParts(string template) @@ -65,13 +74,15 @@ public partial class FileNameFormatProvider if (result.Index != currentIndex) { var constant = template[currentIndex..result.Index]; - parts.Add(new FileNameFormatPart(constant, null)); + parts.Add(FileNameFormatPart.FromConstant(constant)); currentIndex += constant.Length; } - var variable = result.Value[1..^1]; - parts.Add(new FileNameFormatPart(null, Substitutions[variable])); + // Now we're at start of the current match, add the variable part + var variable = result.Groups[1].Value; + + parts.Add(FileNameFormatPart.FromSubstitution(Substitutions[variable])); currentIndex += result.Length; } @@ -80,15 +91,34 @@ public partial class FileNameFormatProvider if (currentIndex != template.Length) { var constant = template[currentIndex..]; - parts.Add(new FileNameFormatPart(constant, null)); + parts.Add(FileNameFormatPart.FromConstant(constant)); } return parts; } + /// + /// Return a sample provider for UI preview + /// + public static FileNameFormatProvider GetSample() + { + return new FileNameFormatProvider + { + GenerationParameters = GenerationParameters.GetSample(), + ProjectType = InferenceProjectType.TextToImage, + ProjectName = "Sample Project" + }; + } + /// /// Regex for matching contents within a curly brace. /// - [GeneratedRegex(@"\{[a-z_]+\}")] + [GeneratedRegex(@"\{([a-z_]+)\}")] private static partial Regex BracketRegex(); + + /// + /// Regex for matching a Python-like array index. + /// + [GeneratedRegex(@"\[(?:(?-?\d+)?)\:(?:(?-?\d+)?)?(?:\:(?-?\d+))?\]")] + private static partial Regex IndexRegex(); } diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs new file mode 100644 index 00000000..a453b3bc --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs @@ -0,0 +1,8 @@ +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormatVar +{ + public required string Variable { get; init; } + + public string? Example { get; init; } +} diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj index ea70b345..1c95a536 100644 --- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj +++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj @@ -32,6 +32,7 @@ + diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index c7fe35e0..d3b746f6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -10,7 +10,6 @@ using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Controls.Notifications; -using Avalonia.Media.Imaging; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using NLog; diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs index 97772909..850a397c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs @@ -3,10 +3,12 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Collections.ObjectModel; using System.ComponentModel; +using System.ComponentModel.DataAnnotations; using System.Diagnostics; using System.Globalization; using System.IO; using System.Linq; +using System.Reactive.Linq; using System.Reflection; using System.Text; using System.Text.Json; @@ -21,6 +23,7 @@ using Avalonia.Styling; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using DynamicData.Binding; using FluentAvalonia.UI.Controls; using NLog; using SkiaSharp; @@ -29,6 +32,7 @@ using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; @@ -108,8 +112,24 @@ public partial class SettingsViewModel : PageViewModelBase private bool isCompletionRemoveUnderscoresEnabled = true; [ObservableProperty] + [CustomValidation(typeof(SettingsViewModel), nameof(ValidateOutputImageFileNameFormat))] private string? outputImageFileNameFormat; + [ObservableProperty] + private string? outputImageFileNameFormatSample; + + public IEnumerable OutputImageFileNameFormatVars => + FileNameFormatProvider + .GetSample() + .Substitutions.Select( + kv => + new FileNameFormatVar + { + Variable = $"{{{kv.Key}}}", + Example = kv.Value.Invoke() + } + ); + [ObservableProperty] private bool isImageViewerPixelGridEnabled = true; @@ -204,6 +224,32 @@ public partial class SettingsViewModel : PageViewModelBase true ); + this.WhenPropertyChanged(vm => vm.OutputImageFileNameFormat) + .Throttle(TimeSpan.FromMilliseconds(50)) + .Subscribe(formatProperty => + { + var provider = FileNameFormatProvider.GetSample(); + var template = formatProperty.Value; + + if ( + !string.IsNullOrEmpty(template) + && provider.Validate(template) == ValidationResult.Success + ) + { + var format = FileNameFormat.Parse(template, provider); + OutputImageFileNameFormatSample = format.GetFileName() + ".png"; + } + else + { + // Use default format if empty + var defaultFormat = FileNameFormat.Parse( + FileNameFormat.DefaultTemplate, + provider + ); + OutputImageFileNameFormatSample = defaultFormat.GetFileName() + ".png"; + } + }); + settingsManager.RelayPropertyFor( this, vm => vm.OutputImageFileNameFormat, @@ -235,6 +281,14 @@ public partial class SettingsViewModel : PageViewModelBase UpdateAvailableTagCompletionCsvs(); } + public static ValidationResult ValidateOutputImageFileNameFormat( + string format, + ValidationContext context + ) + { + return FileNameFormatProvider.GetSample().Validate(format); + } + partial void OnSelectedThemeChanged(string? value) { // In case design / tests diff --git a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml index d51bce28..18d65bb8 100644 --- a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml +++ b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml @@ -12,6 +12,9 @@ xmlns:vm="clr-namespace:StabilityMatrix.Avalonia.ViewModels" xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages" xmlns:avaloniaEdit="https://github.com/avaloniaui/avaloniaedit" + xmlns:inference="clr-namespace:StabilityMatrix.Avalonia.Models.Inference" + xmlns:mdxaml="https://github.com/whistyun/Markdown.Avalonia.Tight" + Focusable="True" d:DataContext="{x:Static mocks:DesignData.SettingsViewModel}" d:DesignHeight="700" d:DesignWidth="800" @@ -168,15 +171,32 @@ - + FontFamily="Cascadia Code,Consolas,Menlo,Monospace"/> + + + + + diff --git a/StabilityMatrix.Core/Models/GenerationParameters.cs b/StabilityMatrix.Core/Models/GenerationParameters.cs index fab9f793..9ab6c778 100644 --- a/StabilityMatrix.Core/Models/GenerationParameters.cs +++ b/StabilityMatrix.Core/Models/GenerationParameters.cs @@ -126,6 +126,26 @@ public partial record GenerationParameters return (sampler, scheduler); } + /// + /// Return a sample parameters for UI preview + /// + public static GenerationParameters GetSample() + { + return new GenerationParameters + { + PositivePrompt = "(cat:1.2), by artist, detailed, [shaded]", + NegativePrompt = "blurry, jpg artifacts", + Steps = 30, + CfgScale = 7, + Width = 640, + Height = 896, + Seed = 124825529, + ModelName = "ExampleMix7", + ModelHash = "b899d188a1ac7356bfb9399b2277d5b21712aa360f8f9514fba6fcce021baff7", + Sampler = "DPM++ 2M Karras" + }; + } + // Example: Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1 [GeneratedRegex( """^Steps: (?\d+), Sampler: (?.+?), CFG scale: (?\d+(\.\d+)?), Seed: (?\d+), Size: (?\d+)x(?\d+), Model hash: (?.+?), Model: (?.+)$""" From 7d980c08abeab8e38994312dbe50cadf4dcb0429 Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 18:16:54 -0400 Subject: [PATCH 09/13] Fix tests --- .../Avalonia/FileNameFormatProviderTests.cs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs index cdf7fdfa..5905aca0 100644 --- a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs @@ -1,4 +1,5 @@ -using StabilityMatrix.Avalonia.Models.Inference; +using System.ComponentModel.DataAnnotations; +using StabilityMatrix.Avalonia.Models.Inference; namespace StabilityMatrix.Tests.Avalonia; @@ -10,7 +11,8 @@ public class FileNameFormatProviderTests { var provider = new FileNameFormatProvider(); - provider.Validate("{date}_{time}-{model_name}-{seed}"); + var result = provider.Validate("{date}_{time}-{model_name}-{seed}"); + Assert.AreEqual(ValidationResult.Success, result); } [TestMethod] @@ -18,8 +20,9 @@ public class FileNameFormatProviderTests { var provider = new FileNameFormatProvider(); - Assert.ThrowsException( - () => provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}") - ); + var result = provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}"); + Assert.AreNotEqual(ValidationResult.Success, result); + + Assert.AreEqual("Unknown variable 'invalid'", result.ErrorMessage); } } From 3920ccb1e945aee3eaa1a84b15951305cdbd71e4 Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 18:32:02 -0400 Subject: [PATCH 10/13] Version bump --- StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj index 1c95a536..a6ed93d0 100644 --- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj +++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj @@ -8,7 +8,7 @@ app.manifest true ./Assets/Icon.ico - 2.5.3-dev.1 + 2.6.0-dev.1 $(Version) true true From 4d36d66f9687b56c527fe974cb949e845a2c6284 Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 22:16:07 -0400 Subject: [PATCH 11/13] Add variable slice support and prompts --- .../Inference/FileNameFormatProvider.cs | 80 +++++++++++++++++-- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs index 96cacf04..1422efa1 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -25,6 +25,8 @@ public partial class FileNameFormatProvider _substitutions ??= new Dictionary> { { "seed", () => GenerationParameters?.Seed.ToString() }, + { "prompt", () => GenerationParameters?.PositivePrompt }, + { "negative_prompt", () => GenerationParameters?.NegativePrompt }, { "model_name", () => GenerationParameters?.ModelName }, { "model_hash", () => GenerationParameters?.ModelHash }, { "width", () => GenerationParameters?.Width.ToString() }, @@ -47,11 +49,20 @@ public partial class FileNameFormatProvider var matches = regex.Matches(format); var variables = matches.Select(m => m.Groups[1].Value); - foreach (var variable in variables) + foreach (var variableText in variables) { - if (!Substitutions.ContainsKey(variable)) + try { - return new ValidationResult($"Unknown variable '{variable}'"); + var (variable, _) = ExtractVariableAndSlice(variableText); + + if (!Substitutions.ContainsKey(variable)) + { + return new ValidationResult($"Unknown variable '{variable}'"); + } + } + catch (Exception e) + { + return new ValidationResult($"Invalid variable '{variableText}': {e.Message}"); } } @@ -80,9 +91,38 @@ public partial class FileNameFormatProvider } // Now we're at start of the current match, add the variable part - var variable = result.Groups[1].Value; + var (variable, slice) = ExtractVariableAndSlice(result.Groups[1].Value); + var substitution = Substitutions[variable]; - parts.Add(FileNameFormatPart.FromSubstitution(Substitutions[variable])); + // Slice string if necessary + if (slice is not null) + { + parts.Add( + FileNameFormatPart.FromSubstitution(() => + { + var value = substitution(); + if (value is null) + return null; + + if (slice.End is null) + { + value = value[(slice.Start ?? 0)..]; + } + else + { + var length = + Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0); + value = value.Substring(slice.Start ?? 0, length); + } + + return value; + }) + ); + } + else + { + parts.Add(FileNameFormatPart.FromSubstitution(substitution)); + } currentIndex += result.Length; } @@ -110,10 +150,36 @@ public partial class FileNameFormatProvider }; } + /// + /// Extract variable and index from a combined string + /// + private static (string Variable, Slice? Slice) ExtractVariableAndSlice(string combined) + { + if (IndexRegex().Matches(combined).FirstOrDefault() is not { Success: true } match) + { + return (combined, null); + } + + // Variable is everything before the match + var variable = combined[..match.Groups[0].Index]; + + var start = match.Groups["start"].Value; + var end = match.Groups["end"].Value; + var step = match.Groups["step"].Value; + + var slice = new Slice( + string.IsNullOrEmpty(start) ? null : int.Parse(start), + string.IsNullOrEmpty(end) ? null : int.Parse(end), + string.IsNullOrEmpty(step) ? null : int.Parse(step) + ); + + return (variable, slice); + } + /// /// Regex for matching contents within a curly brace. /// - [GeneratedRegex(@"\{([a-z_]+)\}")] + [GeneratedRegex(@"\{([a-z_:\d\[\]]+)\}")] private static partial Regex BracketRegex(); /// @@ -121,4 +187,6 @@ public partial class FileNameFormatProvider /// [GeneratedRegex(@"\[(?:(?-?\d+)?)\:(?:(?-?\d+)?)?(?:\:(?-?\d+))?\]")] private static partial Regex IndexRegex(); + + private record Slice(int? Start, int? End, int? Step); } From 37147c220eacb1b95ea1d20dd6f470796e62cbec Mon Sep 17 00:00:00 2001 From: Ionite Date: Fri, 13 Oct 2023 15:27:17 -0400 Subject: [PATCH 12/13] Fix existing file name id appending --- .../ViewModels/Base/InferenceGenerationViewModelBase.cs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index d3b746f6..f1d0d11c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -152,8 +152,8 @@ public abstract partial class InferenceGenerationViewModelBase format = format.WithBatchPostFix(batchNum, batchTotal); } - var fileName = format.GetFileName() + ".png"; - var file = outputDir.JoinFile(fileName); + var fileName = format.GetFileName(); + var file = outputDir.JoinFile($"{fileName}.png"); // Until the file is free, keep adding _{i} to the end for (var i = 0; i < 100; i++) @@ -161,13 +161,14 @@ public abstract partial class InferenceGenerationViewModelBase if (!file.Exists) break; - file = outputDir.JoinFile($"{fileName}_{i + 1}"); + file = outputDir.JoinFile($"{fileName}_{i + 1}.png"); } // If that fails, append an 7-char uuid if (file.Exists) { - file = outputDir.JoinFile($"{fileName}_{Guid.NewGuid():N}"[..7]); + var uuid = Guid.NewGuid().ToString("N")[..7]; + file = outputDir.JoinFile($"{fileName}_{uuid}.png"); } await using var fileStream = file.Info.OpenWrite(); From 205218b01ab0b6a16aa92944c80069d49b18cfdd Mon Sep 17 00:00:00 2001 From: Ionite Date: Fri, 13 Oct 2023 15:28:56 -0400 Subject: [PATCH 13/13] Completion popup wrapping for arrow key navigation --- .../Controls/CodeCompletion/CompletionList.cs | 4 ++-- .../CodeCompletion/CompletionListBox.cs | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs index 9150209b..79f2b766 100644 --- a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs +++ b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs @@ -202,11 +202,11 @@ public class CompletionList : TemplatedControl { case Key.Down: e.Handled = true; - _listBox.SelectIndex(_listBox.SelectedIndex + 1); + _listBox.SelectNextIndexWithLoop(); break; case Key.Up: e.Handled = true; - _listBox.SelectIndex(_listBox.SelectedIndex - 1); + _listBox.SelectPreviousIndexWithLoop(); break; case Key.PageDown: e.Handled = true; diff --git a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionListBox.cs b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionListBox.cs index 4313be17..90f40a28 100644 --- a/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionListBox.cs +++ b/StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionListBox.cs @@ -91,6 +91,28 @@ public class CompletionListBox : ListBox SelectedIndex = -1; } + /// + /// Selects the next item. If the last item is already selected, selects the first item. + /// + public void SelectNextIndexWithLoop() + { + if (ItemCount <= 0) + return; + + SelectIndex((SelectedIndex + 1) % ItemCount); + } + + /// + /// Selects the previous item. If the first item is already selected, selects the last item. + /// + public void SelectPreviousIndexWithLoop() + { + if (ItemCount <= 0) + return; + + SelectIndex((SelectedIndex - 1 + ItemCount) % ItemCount); + } + /// /// Selects the item with the specified index and scrolls it into view. ///