From 3664b1d3ad57cf4d167464c1a8199f6e220618ee Mon Sep 17 00:00:00 2001 From: Ionite Date: Fri, 6 Oct 2023 16:07:04 -0400 Subject: [PATCH] Add prompt tests --- .../Models/TagCompletion/TokenizerProvider.cs | 15 ++- StabilityMatrix.Tests/Avalonia/PromptTests.cs | 114 ++++++++++++++++++ 2 files changed, 121 insertions(+), 8 deletions(-) create mode 100644 StabilityMatrix.Tests/Avalonia/PromptTests.cs diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs index 3a958878..74b22d44 100644 --- a/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs @@ -8,16 +8,15 @@ namespace StabilityMatrix.Avalonia.Models.TagCompletion; public class TokenizerProvider : ITokenizerProvider { private readonly Registry registry = new(new RegistryOptions(ThemeName.DarkPlus)); - private IGrammar grammar; - - public TokenizerProvider() - { - SetPromptGrammar(); - } - + private IGrammar? grammar; + /// public ITokenizeLineResult TokenizeLine(string lineText) { + if (grammar is null) + { + SetPromptGrammar(); + } return grammar.TokenizeLine(lineText); } @@ -27,7 +26,7 @@ public class TokenizerProvider : ITokenizerProvider using var stream = Assets.ImagePromptLanguageJson.Open(); grammar = registry.LoadGrammarFromStream(stream); } - + public void SetGrammar(string scopeName) { grammar = registry.LoadGrammar(scopeName); diff --git a/StabilityMatrix.Tests/Avalonia/PromptTests.cs b/StabilityMatrix.Tests/Avalonia/PromptTests.cs new file mode 100644 index 00000000..60b556a0 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/PromptTests.cs @@ -0,0 +1,114 @@ +using System.Globalization; +using System.Reflection; +using NSubstitute; +using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Models.TagCompletion; +using StabilityMatrix.Core.Models.Tokens; +using TextMateSharp.Grammars; +using TextMateSharp.Registry; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class PromptTests +{ + private ITokenizerProvider tokenizerProvider = null!; + + [TestInitialize] + public void TestInitialize() + { + tokenizerProvider = Substitute.For(); + + var promptSyntaxFile = Assembly + .GetExecutingAssembly() + .GetManifestResourceStream("StabilityMatrix.Tests.ImagePrompt.tmLanguage.json")!; + + var registry = new Registry(new RegistryOptions(ThemeName.DarkPlus)); + var grammar = registry.LoadGrammarFromStream(promptSyntaxFile); + + tokenizerProvider + .TokenizeLine(Arg.Any()) + .Returns(x => grammar.TokenizeLine(x.Arg())); + } + + [TestMethod] + public void TestPromptProcessedText() + { + var prompt = Prompt.FromRawText("test", tokenizerProvider); + + prompt.Process(); + + Assert.AreEqual("test", prompt.ProcessedText); + } + + [TestMethod] + public void TestPromptWeightParsing() + { + var prompt = Prompt.FromRawText("", tokenizerProvider); + + prompt.Process(); + + // Output should have no loras + Assert.AreEqual("", prompt.ProcessedText); + + var network = prompt.ExtraNetworks[0]; + + Assert.AreEqual(PromptExtraNetworkType.Lora, network.Type); + Assert.AreEqual("my_model", network.Name); + Assert.AreEqual(1.5f, network.ModelWeight); + } + + /// + /// Tests that we can parse decimal numbers with different cultures + /// + [TestMethod] + public void TestPromptWeightParsing_DecimalSeparatorCultures_ShouldParse() + { + var prompt = Prompt.FromRawText("", tokenizerProvider); + + // Cultures like de-DE use commas as decimal separators, check that we can parse those too + ExecuteWithCulture(prompt.Process, CultureInfo.GetCultureInfo("de-DE")); + + // Output should have no loras + Assert.AreEqual("", prompt.ProcessedText); + + var network = prompt.ExtraNetworks![0]; + + Assert.AreEqual(PromptExtraNetworkType.Lora, network.Type); + Assert.AreEqual("my_model", network.Name); + Assert.AreEqual(1.5f, network.ModelWeight); + } + + private static T? ExecuteWithCulture(Func func, CultureInfo culture) + { + var result = default(T); + + var thread = new Thread(() => + { + result = func(); + }) + { + CurrentCulture = culture + }; + + thread.Start(); + thread.Join(); + + return result; + } + + private static void ExecuteWithCulture(Action func, CultureInfo culture) + { + var thread = new Thread(() => + { + func(); + }) + { + CurrentCulture = culture + }; + + thread.Start(); + thread.Join(); + } +}