Browse Source

Add file name formatting frameworks

pull/240/head
Ionite 1 year ago
parent
commit
a40e4e5a98
No known key found for this signature in database
  1. 68
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs
  2. 5
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs
  3. 111
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs
  4. 25
      StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs
  5. 24
      StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs

68
StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs

@ -0,0 +1,68 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
namespace StabilityMatrix.Avalonia.Models.Inference;
public record FileNameFormat
{
public string Template { get; }
public string Prefix { get; set; } = "";
public string Postfix { get; set; } = "";
public IReadOnlyList<FileNameFormatPart> Parts { get; }
private FileNameFormat(string template, IReadOnlyList<FileNameFormatPart> parts)
{
Template = template;
Parts = parts;
}
public FileNameFormat WithBatchPostFix(int current, int total)
{
return this with { Postfix = Postfix + $" ({current}-{total})" };
}
public FileNameFormat WithGridPrefix()
{
return this with { Prefix = Prefix + "Grid_" };
}
public string GetFileName()
{
return Prefix
+ string.Join("", Parts.Select(p => p.Constant ?? p.Substitution?.Invoke() ?? ""))
+ Postfix;
}
public static FileNameFormat Parse(string template, FileNameFormatProvider provider)
{
provider.Validate(template);
var parts = provider.GetParts(template).ToImmutableArray();
return new FileNameFormat(template, parts);
}
public static bool TryParse(
string template,
FileNameFormatProvider provider,
[NotNullWhen(true)] out FileNameFormat? format
)
{
try
{
format = Parse(template, provider);
return true;
}
catch (ArgumentException)
{
format = null;
return false;
}
}
public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}";
}

5
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs

@ -0,0 +1,5 @@
using System;
namespace StabilityMatrix.Avalonia.Models.Inference;
public record FileNameFormatPart(string? Constant, Func<string?>? Substitution);

111
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs

@ -0,0 +1,111 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Models.Inference;
public partial class FileNameFormatProvider
{
public GenerationParameters? GenerationParameters { get; init; }
public InferenceProjectType? ProjectType { get; init; }
public string? ProjectName { get; init; }
private Dictionary<string, Func<string?>>? _substitutions;
private Dictionary<string, Func<string?>> Substitutions =>
_substitutions ??= new Dictionary<string, Func<string?>>
{
{ "seed", () => GenerationParameters?.Seed.ToString() },
{ "model_name", () => GenerationParameters?.ModelName },
{ "model_hash", () => GenerationParameters?.ModelHash },
{ "width", () => GenerationParameters?.Width.ToString() },
{ "height", () => GenerationParameters?.Height.ToString() },
{ "project_type", () => ProjectType?.GetStringValue() },
{ "project_name", () => ProjectName },
{ "date", () => DateTime.Now.ToString("yyyy-MM-dd") },
{ "time", () => DateTime.Now.ToString("HH-mm-ss") }
};
public (int Current, int Total)? BatchInfo { get; init; }
/// <summary>
/// Validate a format string
/// </summary>
public void Validate(string format)
{
var regex = BracketRegex();
var matches = regex.Matches(format);
var variables = matches.Select(m => m.Value[1..^1]).ToList();
foreach (var variable in variables)
{
if (!Substitutions.ContainsKey(variable))
{
throw new ArgumentException($"Unknown variable '{variable}'");
}
}
}
public IEnumerable<FileNameFormatPart> GetParts(string template)
{
var regex = BracketRegex();
var matches = regex.Matches(template);
var parts = new List<FileNameFormatPart>();
// Loop through all parts of the string, including matches and non-matches
var currentIndex = 0;
foreach (var result in matches.Cast<Match>())
{
// If the match is not at the start of the string, add a constant part
if (result.Index != currentIndex)
{
var constant = template[currentIndex..result.Index];
parts.Add(new FileNameFormatPart(constant, null));
currentIndex += constant.Length;
}
var variable = result.Value[1..^1];
parts.Add(new FileNameFormatPart(null, Substitutions[variable]));
currentIndex += result.Length;
}
// Add remaining as constant
if (currentIndex != template.Length)
{
var constant = template[currentIndex..];
parts.Add(new FileNameFormatPart(constant, null));
}
return parts;
}
/// <summary>
/// Return a string substituting the variables in the format string
/// </summary>
private string? GetSubstitution(string variable)
{
return variable switch
{
"seed" => GenerationParameters.Seed.ToString(),
"model_name" => GenerationParameters.ModelName,
"model_hash" => GenerationParameters.ModelHash,
"width" => GenerationParameters.Width.ToString(),
"height" => GenerationParameters.Height.ToString(),
"date" => DateTime.Now.ToString("yyyy-MM-dd"),
"time" => DateTime.Now.ToString("HH-mm-ss"),
_ => throw new ArgumentOutOfRangeException(nameof(variable), variable, null)
};
}
[GeneratedRegex(@"\{[a-z_]+\}")]
private static partial Regex BracketRegex();
}

25
StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs

@ -0,0 +1,25 @@
using StabilityMatrix.Avalonia.Models.Inference;
namespace StabilityMatrix.Tests.Avalonia;
[TestClass]
public class FileNameFormatProviderTests
{
[TestMethod]
public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow()
{
var provider = new FileNameFormatProvider();
provider.Validate("{date}_{time}-{model_name}-{seed}");
}
[TestMethod]
public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow()
{
var provider = new FileNameFormatProvider();
Assert.ThrowsException<ArgumentException>(
() => provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}")
);
}
}

24
StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs

@ -0,0 +1,24 @@
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Tests.Avalonia;
[TestClass]
public class FileNameFormatTests
{
[TestMethod]
public void TestFileNameFormatParse()
{
var provider = new FileNameFormatProvider
{
GenerationParameters = new GenerationParameters { Seed = 123 },
ProjectName = "uwu",
ProjectType = InferenceProjectType.TextToImage,
};
var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider);
Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName());
}
}
Loading…
Cancel
Save