From 4d36d66f9687b56c527fe974cb949e845a2c6284 Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 12 Oct 2023 22:16:07 -0400 Subject: [PATCH] Add variable slice support and prompts --- .../Inference/FileNameFormatProvider.cs | 80 +++++++++++++++++-- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs index 96cacf04..1422efa1 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -25,6 +25,8 @@ public partial class FileNameFormatProvider _substitutions ??= new Dictionary> { { "seed", () => GenerationParameters?.Seed.ToString() }, + { "prompt", () => GenerationParameters?.PositivePrompt }, + { "negative_prompt", () => GenerationParameters?.NegativePrompt }, { "model_name", () => GenerationParameters?.ModelName }, { "model_hash", () => GenerationParameters?.ModelHash }, { "width", () => GenerationParameters?.Width.ToString() }, @@ -47,11 +49,20 @@ public partial class FileNameFormatProvider var matches = regex.Matches(format); var variables = matches.Select(m => m.Groups[1].Value); - foreach (var variable in variables) + foreach (var variableText in variables) { - if (!Substitutions.ContainsKey(variable)) + try { - return new ValidationResult($"Unknown variable '{variable}'"); + var (variable, _) = ExtractVariableAndSlice(variableText); + + if (!Substitutions.ContainsKey(variable)) + { + return new ValidationResult($"Unknown variable '{variable}'"); + } + } + catch (Exception e) + { + return new ValidationResult($"Invalid variable '{variableText}': {e.Message}"); } } @@ -80,9 +91,38 @@ public partial class FileNameFormatProvider } // Now we're at start of the current match, add the variable part - var variable = result.Groups[1].Value; + var (variable, slice) = ExtractVariableAndSlice(result.Groups[1].Value); + var substitution = Substitutions[variable]; - parts.Add(FileNameFormatPart.FromSubstitution(Substitutions[variable])); + // Slice string if necessary + if (slice is not null) + { + parts.Add( + FileNameFormatPart.FromSubstitution(() => + { + var value = substitution(); + if (value is null) + return null; + + if (slice.End is null) + { + value = value[(slice.Start ?? 0)..]; + } + else + { + var length = + Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0); + value = value.Substring(slice.Start ?? 0, length); + } + + return value; + }) + ); + } + else + { + parts.Add(FileNameFormatPart.FromSubstitution(substitution)); + } currentIndex += result.Length; } @@ -110,10 +150,36 @@ public partial class FileNameFormatProvider }; } + /// + /// Extract variable and index from a combined string + /// + private static (string Variable, Slice? Slice) ExtractVariableAndSlice(string combined) + { + if (IndexRegex().Matches(combined).FirstOrDefault() is not { Success: true } match) + { + return (combined, null); + } + + // Variable is everything before the match + var variable = combined[..match.Groups[0].Index]; + + var start = match.Groups["start"].Value; + var end = match.Groups["end"].Value; + var step = match.Groups["step"].Value; + + var slice = new Slice( + string.IsNullOrEmpty(start) ? null : int.Parse(start), + string.IsNullOrEmpty(end) ? null : int.Parse(end), + string.IsNullOrEmpty(step) ? null : int.Parse(step) + ); + + return (variable, slice); + } + /// /// Regex for matching contents within a curly brace. /// - [GeneratedRegex(@"\{([a-z_]+)\}")] + [GeneratedRegex(@"\{([a-z_:\d\[\]]+)\}")] private static partial Regex BracketRegex(); /// @@ -121,4 +187,6 @@ public partial class FileNameFormatProvider /// [GeneratedRegex(@"\[(?:(?-?\d+)?)\:(?:(?-?\d+)?)?(?:\:(?-?\d+))?\]")] private static partial Regex IndexRegex(); + + private record Slice(int? Start, int? End, int? Step); }