Ionite
1 year ago
5 changed files with 233 additions and 0 deletions
@ -0,0 +1,68 @@
|
||||
using System; |
||||
using System.Collections.Generic; |
||||
using System.Collections.Immutable; |
||||
using System.Diagnostics.CodeAnalysis; |
||||
using System.Linq; |
||||
|
||||
namespace StabilityMatrix.Avalonia.Models.Inference; |
||||
|
||||
public record FileNameFormat |
||||
{ |
||||
public string Template { get; } |
||||
|
||||
public string Prefix { get; set; } = ""; |
||||
|
||||
public string Postfix { get; set; } = ""; |
||||
|
||||
public IReadOnlyList<FileNameFormatPart> Parts { get; } |
||||
|
||||
private FileNameFormat(string template, IReadOnlyList<FileNameFormatPart> parts) |
||||
{ |
||||
Template = template; |
||||
Parts = parts; |
||||
} |
||||
|
||||
public FileNameFormat WithBatchPostFix(int current, int total) |
||||
{ |
||||
return this with { Postfix = Postfix + $" ({current}-{total})" }; |
||||
} |
||||
|
||||
public FileNameFormat WithGridPrefix() |
||||
{ |
||||
return this with { Prefix = Prefix + "Grid_" }; |
||||
} |
||||
|
||||
public string GetFileName() |
||||
{ |
||||
return Prefix |
||||
+ string.Join("", Parts.Select(p => p.Constant ?? p.Substitution?.Invoke() ?? "")) |
||||
+ Postfix; |
||||
} |
||||
|
||||
public static FileNameFormat Parse(string template, FileNameFormatProvider provider) |
||||
{ |
||||
provider.Validate(template); |
||||
var parts = provider.GetParts(template).ToImmutableArray(); |
||||
return new FileNameFormat(template, parts); |
||||
} |
||||
|
||||
public static bool TryParse( |
||||
string template, |
||||
FileNameFormatProvider provider, |
||||
[NotNullWhen(true)] out FileNameFormat? format |
||||
) |
||||
{ |
||||
try |
||||
{ |
||||
format = Parse(template, provider); |
||||
return true; |
||||
} |
||||
catch (ArgumentException) |
||||
{ |
||||
format = null; |
||||
return false; |
||||
} |
||||
} |
||||
|
||||
public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}"; |
||||
} |
@ -0,0 +1,5 @@
|
||||
using System; |
||||
|
||||
namespace StabilityMatrix.Avalonia.Models.Inference; |
||||
|
||||
public record FileNameFormatPart(string? Constant, Func<string?>? Substitution); |
@ -0,0 +1,111 @@
|
||||
using System; |
||||
using System.Collections.Generic; |
||||
using System.Linq; |
||||
using System.Text.RegularExpressions; |
||||
using StabilityMatrix.Core.Extensions; |
||||
using StabilityMatrix.Core.Models; |
||||
|
||||
namespace StabilityMatrix.Avalonia.Models.Inference; |
||||
|
||||
public partial class FileNameFormatProvider |
||||
{ |
||||
public GenerationParameters? GenerationParameters { get; init; } |
||||
|
||||
public InferenceProjectType? ProjectType { get; init; } |
||||
|
||||
public string? ProjectName { get; init; } |
||||
|
||||
private Dictionary<string, Func<string?>>? _substitutions; |
||||
|
||||
private Dictionary<string, Func<string?>> Substitutions => |
||||
_substitutions ??= new Dictionary<string, Func<string?>> |
||||
{ |
||||
{ "seed", () => GenerationParameters?.Seed.ToString() }, |
||||
{ "model_name", () => GenerationParameters?.ModelName }, |
||||
{ "model_hash", () => GenerationParameters?.ModelHash }, |
||||
{ "width", () => GenerationParameters?.Width.ToString() }, |
||||
{ "height", () => GenerationParameters?.Height.ToString() }, |
||||
{ "project_type", () => ProjectType?.GetStringValue() }, |
||||
{ "project_name", () => ProjectName }, |
||||
{ "date", () => DateTime.Now.ToString("yyyy-MM-dd") }, |
||||
{ "time", () => DateTime.Now.ToString("HH-mm-ss") } |
||||
}; |
||||
|
||||
public (int Current, int Total)? BatchInfo { get; init; } |
||||
|
||||
/// <summary> |
||||
/// Validate a format string |
||||
/// </summary> |
||||
public void Validate(string format) |
||||
{ |
||||
var regex = BracketRegex(); |
||||
var matches = regex.Matches(format); |
||||
var variables = matches.Select(m => m.Value[1..^1]).ToList(); |
||||
|
||||
foreach (var variable in variables) |
||||
{ |
||||
if (!Substitutions.ContainsKey(variable)) |
||||
{ |
||||
throw new ArgumentException($"Unknown variable '{variable}'"); |
||||
} |
||||
} |
||||
} |
||||
|
||||
public IEnumerable<FileNameFormatPart> GetParts(string template) |
||||
{ |
||||
var regex = BracketRegex(); |
||||
var matches = regex.Matches(template); |
||||
|
||||
var parts = new List<FileNameFormatPart>(); |
||||
|
||||
// Loop through all parts of the string, including matches and non-matches |
||||
var currentIndex = 0; |
||||
|
||||
foreach (var result in matches.Cast<Match>()) |
||||
{ |
||||
// If the match is not at the start of the string, add a constant part |
||||
if (result.Index != currentIndex) |
||||
{ |
||||
var constant = template[currentIndex..result.Index]; |
||||
parts.Add(new FileNameFormatPart(constant, null)); |
||||
|
||||
currentIndex += constant.Length; |
||||
} |
||||
|
||||
var variable = result.Value[1..^1]; |
||||
parts.Add(new FileNameFormatPart(null, Substitutions[variable])); |
||||
|
||||
currentIndex += result.Length; |
||||
} |
||||
|
||||
// Add remaining as constant |
||||
if (currentIndex != template.Length) |
||||
{ |
||||
var constant = template[currentIndex..]; |
||||
parts.Add(new FileNameFormatPart(constant, null)); |
||||
} |
||||
|
||||
return parts; |
||||
} |
||||
|
||||
/// <summary> |
||||
/// Return a string substituting the variables in the format string |
||||
/// </summary> |
||||
private string? GetSubstitution(string variable) |
||||
{ |
||||
return variable switch |
||||
{ |
||||
"seed" => GenerationParameters.Seed.ToString(), |
||||
"model_name" => GenerationParameters.ModelName, |
||||
"model_hash" => GenerationParameters.ModelHash, |
||||
"width" => GenerationParameters.Width.ToString(), |
||||
"height" => GenerationParameters.Height.ToString(), |
||||
"date" => DateTime.Now.ToString("yyyy-MM-dd"), |
||||
"time" => DateTime.Now.ToString("HH-mm-ss"), |
||||
_ => throw new ArgumentOutOfRangeException(nameof(variable), variable, null) |
||||
}; |
||||
} |
||||
|
||||
[GeneratedRegex(@"\{[a-z_]+\}")]
|
||||
private static partial Regex BracketRegex(); |
||||
} |
@ -0,0 +1,25 @@
|
||||
using StabilityMatrix.Avalonia.Models.Inference; |
||||
|
||||
namespace StabilityMatrix.Tests.Avalonia; |
||||
|
||||
[TestClass] |
||||
public class FileNameFormatProviderTests |
||||
{ |
||||
[TestMethod] |
||||
public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow() |
||||
{ |
||||
var provider = new FileNameFormatProvider(); |
||||
|
||||
provider.Validate("{date}_{time}-{model_name}-{seed}"); |
||||
} |
||||
|
||||
[TestMethod] |
||||
public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow() |
||||
{ |
||||
var provider = new FileNameFormatProvider(); |
||||
|
||||
Assert.ThrowsException<ArgumentException>( |
||||
() => provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}") |
||||
); |
||||
} |
||||
} |
@ -0,0 +1,24 @@
|
||||
using StabilityMatrix.Avalonia.Models; |
||||
using StabilityMatrix.Avalonia.Models.Inference; |
||||
using StabilityMatrix.Core.Models; |
||||
|
||||
namespace StabilityMatrix.Tests.Avalonia; |
||||
|
||||
[TestClass] |
||||
public class FileNameFormatTests |
||||
{ |
||||
[TestMethod] |
||||
public void TestFileNameFormatParse() |
||||
{ |
||||
var provider = new FileNameFormatProvider |
||||
{ |
||||
GenerationParameters = new GenerationParameters { Seed = 123 }, |
||||
ProjectName = "uwu", |
||||
ProjectType = InferenceProjectType.TextToImage, |
||||
}; |
||||
|
||||
var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider); |
||||
|
||||
Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName()); |
||||
} |
||||
} |
Loading…
Reference in new issue