using System.Globalization;
using System.Reflection;
using System.Runtime.ExceptionServices;
using NSubstitute;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Models.Tokens;
using TextMateSharp.Grammars;
using TextMateSharp.Registry;

namespace StabilityMatrix.Tests.Avalonia;

[TestClass]
public class PromptTests
{
    /// <summary>
    /// Prompt containing a single lora network named "my_model" with weight 1.5,
    /// shared by the weight-parsing tests.
    /// </summary>
    private const string LoraPrompt = "<lora:my_model:1.5>";

    private const string PromptSyntaxResourceName =
        "StabilityMatrix.Tests.ImagePrompt.tmLanguage.json";

    private ITokenizerProvider tokenizerProvider = null!;

    /// <summary>
    /// Creates a substitute <see cref="ITokenizerProvider"/> whose <c>TokenizeLine</c>
    /// delegates to the TextMate grammar loaded from the embedded
    /// <c>ImagePrompt.tmLanguage.json</c> resource.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The grammar resource is not embedded in the test assembly.
    /// </exception>
    [TestInitialize]
    public void TestInitialize()
    {
        tokenizerProvider = Substitute.For<ITokenizerProvider>();

        // Fail with a descriptive message instead of a later NullReferenceException
        // if the grammar file was not embedded into the test assembly.
        using var promptSyntaxFile =
            Assembly.GetExecutingAssembly().GetManifestResourceStream(PromptSyntaxResourceName)
            ?? throw new InvalidOperationException(
                $"Embedded resource '{PromptSyntaxResourceName}' was not found"
            );

        var registry = new Registry(new RegistryOptions(ThemeName.DarkPlus));
        var grammar = registry.LoadGrammarFromStream(promptSyntaxFile);

        // Route every TokenizeLine call on the substitute through the grammar loaded above.
        tokenizerProvider
            .TokenizeLine(Arg.Any<string>())
            .Returns(x => grammar.TokenizeLine(x.Arg<string>()));
    }

    [TestMethod]
    public void TestPromptProcessedText()
    {
        var prompt = Prompt.FromRawText("test", tokenizerProvider);

        prompt.Process();

        Assert.AreEqual("test", prompt.ProcessedText);
    }

    [TestMethod]
    public void TestPromptWeightParsing()
    {
        var prompt = Prompt.FromRawText(LoraPrompt, tokenizerProvider);

        prompt.Process();

        // Output should have no loras
        Assert.AreEqual("", prompt.ProcessedText);
        AssertFirstNetworkIsMyModelLora(prompt);
    }

    /// <summary>
    /// Tests that lora weights still parse when the current culture uses a different
    /// decimal separator.
    /// </summary>
    [TestMethod]
    public void TestPromptWeightParsing_DecimalSeparatorCultures_ShouldParse()
    {
        var prompt = Prompt.FromRawText(LoraPrompt, tokenizerProvider);

        // de-DE uses ',' as its decimal separator; check that the '.' in the weight
        // is still parsed correctly when processing runs under that culture.
        ExecuteWithCulture(prompt.Process, CultureInfo.GetCultureInfo("de-DE"));

        // Output should have no loras
        Assert.AreEqual("", prompt.ProcessedText);
        AssertFirstNetworkIsMyModelLora(prompt);
    }

    /// <summary>
    /// Asserts that the first extra network of a processed <see cref="LoraPrompt"/> is a
    /// lora named "my_model" with a model weight of 1.5.
    /// </summary>
    private static void AssertFirstNetworkIsMyModelLora(Prompt prompt)
    {
        Assert.IsNotNull(prompt.ExtraNetworks, "Expected extra networks to be parsed");

        var network = prompt.ExtraNetworks![0];

        Assert.AreEqual(PromptExtraNetworkType.Lora, network.Type);
        Assert.AreEqual("my_model", network.Name);
        Assert.AreEqual(1.5f, network.ModelWeight);
    }

    /// <summary>
    /// Runs <paramref name="func"/> on a dedicated thread whose current culture is
    /// <paramref name="culture"/>, waits for it to finish, and returns its result.
    /// </summary>
    /// <remarks>
    /// An exception thrown by <paramref name="func"/> is captured on the worker thread
    /// and rethrown on the calling thread with its original stack trace. Without this,
    /// it would escape as an unhandled background-thread exception instead of failing
    /// the test.
    /// </remarks>
    private static T? ExecuteWithCulture<T>(Func<T> func, CultureInfo culture)
    {
        var result = default(T);
        ExceptionDispatchInfo? error = null;

        var thread = new Thread(() =>
        {
            try
            {
                result = func();
            }
            catch (Exception e)
            {
                error = ExceptionDispatchInfo.Capture(e);
            }
        })
        {
            // Allowed here because the thread has not been started yet
            CurrentCulture = culture
        };

        thread.Start();
        thread.Join();

        error?.Throw();
        return result;
    }

    /// <summary>
    /// Runs <paramref name="func"/> on a dedicated thread whose current culture is
    /// <paramref name="culture"/> and waits for it to finish. Exceptions are rethrown on
    /// the calling thread.
    /// </summary>
    private static void ExecuteWithCulture(Action func, CultureInfo culture)
    {
        ExecuteWithCulture<object?>(
            () =>
            {
                func();
                return null;
            },
            culture
        );
    }
}