diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs new file mode 100644 index 00000000..28e9a35a --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormat +{ + public string Template { get; } + + public string Prefix { get; set; } = ""; + + public string Postfix { get; set; } = ""; + + public IReadOnlyList Parts { get; } + + private FileNameFormat(string template, IReadOnlyList parts) + { + Template = template; + Parts = parts; + } + + public FileNameFormat WithBatchPostFix(int current, int total) + { + return this with { Postfix = Postfix + $" ({current}-{total})" }; + } + + public FileNameFormat WithGridPrefix() + { + return this with { Prefix = Prefix + "Grid_" }; + } + + public string GetFileName() + { + return Prefix + + string.Join("", Parts.Select(p => p.Constant ?? p.Substitution?.Invoke() ?? "")) + + Postfix; + } + + public static FileNameFormat Parse(string template, FileNameFormatProvider provider) + { + provider.Validate(template); + var parts = provider.GetParts(template).ToImmutableArray(); + return new FileNameFormat(template, parts); + } + + public static bool TryParse( + string template, + FileNameFormatProvider provider, + [NotNullWhen(true)] out FileNameFormat? format + ) + { + try + { + format = Parse(template, provider); + return true; + } + catch (ArgumentException) + { + format = null; + return false; + } + } + + public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}"; +} diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs new file mode 100644 index 00000000..bfbcc8d9 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs @@ -0,0 +1,5 @@ +using System; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormatPart(string? Constant, Func? Substitution); diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs new file mode 100644 index 00000000..e6ecc563 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -0,0 +1,111 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public partial class FileNameFormatProvider +{ + public GenerationParameters? GenerationParameters { get; init; } + + public InferenceProjectType? ProjectType { get; init; } + + public string? ProjectName { get; init; } + + private Dictionary>? _substitutions; + + private Dictionary> Substitutions => + _substitutions ??= new Dictionary> + { + { "seed", () => GenerationParameters?.Seed.ToString() }, + { "model_name", () => GenerationParameters?.ModelName }, + { "model_hash", () => GenerationParameters?.ModelHash }, + { "width", () => GenerationParameters?.Width.ToString() }, + { "height", () => GenerationParameters?.Height.ToString() }, + { "project_type", () => ProjectType?.GetStringValue() }, + { "project_name", () => ProjectName }, + { "date", () => DateTime.Now.ToString("yyyy-MM-dd") }, + { "time", () => DateTime.Now.ToString("HH-mm-ss") } + }; + + public (int Current, int Total)? BatchInfo { get; init; } + + /// + /// Validate a format string + /// + public void Validate(string format) + { + var regex = BracketRegex(); + var matches = regex.Matches(format); + var variables = matches.Select(m => m.Value[1..^1]).ToList(); + + foreach (var variable in variables) + { + if (!Substitutions.ContainsKey(variable)) + { + throw new ArgumentException($"Unknown variable '{variable}'"); + } + } + } + + public IEnumerable GetParts(string template) + { + var regex = BracketRegex(); + var matches = regex.Matches(template); + + var parts = new List(); + + // Loop through all parts of the string, including matches and non-matches + var currentIndex = 0; + + foreach (var result in matches.Cast()) + { + // If the match is not at the start of the string, add a constant part + if (result.Index != currentIndex) + { + var constant = template[currentIndex..result.Index]; + parts.Add(new FileNameFormatPart(constant, null)); + + currentIndex += constant.Length; + } + + var variable = result.Value[1..^1]; + parts.Add(new FileNameFormatPart(null, Substitutions[variable])); + + currentIndex += result.Length; + } + + // Add remaining as constant + if (currentIndex != template.Length) + { + var constant = template[currentIndex..]; + parts.Add(new FileNameFormatPart(constant, null)); + } + + return parts; + } + + /// + /// Return a string substituting the variables in the format string + /// + private string? GetSubstitution(string variable) + { + return variable switch + { + "seed" => GenerationParameters.Seed.ToString(), + "model_name" => GenerationParameters.ModelName, + "model_hash" => GenerationParameters.ModelHash, + "width" => GenerationParameters.Width.ToString(), + "height" => GenerationParameters.Height.ToString(), + "date" => DateTime.Now.ToString("yyyy-MM-dd"), + "time" => DateTime.Now.ToString("HH-mm-ss"), + _ => throw new ArgumentOutOfRangeException(nameof(variable), variable, null) + }; + } + + [GeneratedRegex(@"\{[a-z_]+\}")] + private static partial Regex BracketRegex(); +} diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs new file mode 100644 index 00000000..cdf7fdfa --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs @@ -0,0 +1,25 @@ +using StabilityMatrix.Avalonia.Models.Inference; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class FileNameFormatProviderTests +{ + [TestMethod] + public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow() + { + var provider = new FileNameFormatProvider(); + + provider.Validate("{date}_{time}-{model_name}-{seed}"); + } + + [TestMethod] + public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow() + { + var provider = new FileNameFormatProvider(); + + Assert.ThrowsException( + () => provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}") + ); + } +} diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs new file mode 100644 index 00000000..0da1eb84 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs @@ -0,0 +1,24 @@ +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class FileNameFormatTests +{ + [TestMethod] + public void TestFileNameFormatParse() + { + var provider = new FileNameFormatProvider + { + GenerationParameters = new GenerationParameters { Seed = 123 }, + ProjectName = "uwu", + ProjectType = InferenceProjectType.TextToImage, + }; + + var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider); + + Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName()); + } +}