using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.Tokens;
using StabilityMatrix.Core.Services;
using TextMateSharp.Grammars;

namespace StabilityMatrix.Avalonia.Models.Inference;

/// <summary>
/// A prompt's raw text together with its tokenization. Extra network tokens of the form
/// <c>&lt;type:name&gt;</c> or <c>&lt;type:name:weight&gt;</c> are separated from the rest
/// of the text by <see cref="Process"/>.
/// </summary>
public record Prompt
{
    /// <summary>
    /// The prompt text. Token start/end indices in <see cref="TokenizeResult"/> are used to
    /// slice this string.
    /// </summary>
    public required string RawText { get; init; }

    /// <summary>
    /// Tokenization result for <see cref="RawText"/>.
    /// </summary>
    public required ITokenizeLineResult TokenizeResult { get; init; }

    /// <summary>
    /// True once <see cref="Process"/> has run; <see cref="ExtraNetworks"/> and
    /// <see cref="ProcessedText"/> are then non-null.
    /// </summary>
    [MemberNotNullWhen(true, nameof(ExtraNetworks), nameof(ProcessedText))]
    public bool IsProcessed { get; private set; }

    /// <summary>
    /// Extra networks specified in prompt.
    /// </summary>
    public IReadOnlyList<PromptExtraNetwork>? ExtraNetworks { get; private set; }

    /// <summary>
    /// Processed text suitable for sending to inference backend.
    /// This excludes extra network (i.e. LORA) tokens.
    /// </summary>
    public string? ProcessedText { get; private set; }

    /// <summary>
    /// Parses the token stream and populates <see cref="ExtraNetworks"/> and
    /// <see cref="ProcessedText"/>. Calls after the first one return immediately.
    /// </summary>
    /// <exception cref="PromptSyntaxError">Thrown on malformed prompt syntax.</exception>
    /// <exception cref="PromptValidationError">
    /// Thrown on an unknown network type or an unparsable weight.
    /// </exception>
    [MemberNotNull(nameof(ExtraNetworks), nameof(ProcessedText))]
    public void Process()
    {
        if (IsProcessed)
            return;

        var (promptExtraNetworks, processedText) = GetExtraNetworks();
        ExtraNetworks = promptExtraNetworks;
        ProcessedText = processedText;

        // Mark as processed so the guard above short-circuits later calls.
        IsProcessed = true;
    }

    /// <summary>
    /// Verifies that extra network files exist locally.
    /// </summary>
    /// <param name="indexService">Model index each referenced network name is looked up in.</param>
    /// <exception cref="PromptValidationError">Thrown if a filename does not exist</exception>
    public void ValidateExtraNetworks(IModelIndexService indexService)
    {
        GetExtraNetworks(indexService);
    }

    /// <summary>
    /// Get ExtraNetworks as local model files and weights.
    /// </summary>
    /// <remarks>
    /// This is an iterator method, so the checks below run when enumeration begins,
    /// not when the method is called.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// Thrown if <see cref="Process"/> has not been called.
    /// </exception>
    /// <exception cref="ApplicationException">
    /// Thrown if a network is not present in the model index.
    /// </exception>
    public IEnumerable<(
        LocalModelFile ModelFile,
        double? ModelWeight,
        double? ClipWeight
    )> GetExtraNetworksAsLocalModels(IModelIndexService indexService)
    {
        if (ExtraNetworks is null)
        {
            throw new InvalidOperationException(
                "Prompt must be processed before calling GetExtraNetworksAsLocalModels"
            );
        }

        foreach (var network in ExtraNetworks)
        {
            var sharedFolderType = network.Type.ConvertTo<SharedFolderType>();

            if (!indexService.ModelIndex.TryGetValue(sharedFolderType, out var modelList))
            {
                throw new ApplicationException($"Model {network.Name} does not exist in index");
            }

            var localModel = modelList.FirstOrDefault(
                m => m.FileNameWithoutExtension == network.Name
            );

            if (localModel == null)
            {
                throw new ApplicationException($"Model {network.Name} does not exist in index");
            }

            yield return (localModel, network.ModelWeight, network.ClipWeight);
        }
    }

    /// <summary>
    /// Clamps a token end index to the length of <see cref="RawText"/>.
    /// </summary>
    private int GetSafeEndIndex(int index)
    {
        return Math.Min(index, RawText.Length);
    }

    /// <summary>
    /// Returns the slice of <see cref="RawText"/> covered by <paramref name="token"/>.
    /// </summary>
    private string GetTokenText(IToken token)
    {
        return RawText[token.StartIndex..GetSafeEndIndex(token.EndIndex)];
    }

    /// <summary>
    /// Advances <paramref name="tokens"/> and returns the new current token.
    /// </summary>
    /// <exception cref="PromptSyntaxError">
    /// Thrown (as UnexpectedEndOfText, spanning <paramref name="previous"/>) when no token remains.
    /// </exception>
    private IToken MoveNextOrThrow(IEnumerator<IToken> tokens, IToken previous)
    {
        if (!tokens.MoveNext())
        {
            throw PromptSyntaxError.UnexpectedEndOfText(
                previous.StartIndex,
                GetSafeEndIndex(previous.EndIndex)
            );
        }

        return tokens.Current;
    }

    /// <summary>
    /// Walks the token stream, collecting extra network tokens and rebuilding the remaining text.
    /// </summary>
    /// <param name="indexService">
    /// When non-null, each network's model name is checked against this index.
    /// </param>
    /// <returns>
    /// The non-embedding extra networks found, and the prompt text with comments and network
    /// tokens removed. Embeddings are inlined into the text in Comfy format instead.
    /// </returns>
    private (List<PromptExtraNetwork> promptExtraNetworks, string processedText) GetExtraNetworks(
        IModelIndexService? indexService = null
    )
    {
        // Parse tokens "meta.structure.network.prompt"
        // "<": "punctuation.definition.network.begin.prompt"
        // (type): "meta.embedded.network.type.prompt"
        // ":": "punctuation.separator.variable.prompt"
        // (content): "meta.embedded.network.model.prompt"
        // ">": "punctuation.definition.network.end.prompt"
        using var tokens = TokenizeResult.Tokens.Cast<IToken>().GetEnumerator();

        // Store non-network tokens
        var outputTokens = new Stack<IToken>();
        var outputText = new Stack<string>();

        // Store extra networks
        var promptExtraNetworks = new List<PromptExtraNetwork>();

        while (tokens.MoveNext())
        {
            var currentToken = tokens.Current;

            // For any invalid syntax, throw
            if (currentToken.Scopes.Any(s => s.Contains("invalid.illegal")))
            {
                throw new PromptSyntaxError(
                    "Invalid Token",
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Comments - ignore
            if (currentToken.Scopes.Any(s => s.Contains("comment.line")))
            {
                continue;
            }

            // Find start of network token, until then just add to output
            if (!currentToken.Scopes.Contains("punctuation.definition.network.begin.prompt"))
            {
                outputTokens.Push(currentToken);
                outputText.Push(GetTokenText(currentToken));
                continue;
            }

            // Expect next token to be network type
            currentToken = MoveNextOrThrow(tokens, currentToken);
            if (!currentToken.Scopes.Contains("meta.embedded.network.type.prompt"))
            {
                throw PromptSyntaxError.Network_ExpectedType(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Match network type
            var parsedNetworkType = GetTokenText(currentToken) switch
            {
                "lora" => PromptExtraNetworkType.Lora,
                "lyco" => PromptExtraNetworkType.LyCORIS,
                "embedding" => PromptExtraNetworkType.Embedding,
                _
                    => throw PromptValidationError.Network_UnknownType(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    )
            };

            // Expect a colon separator
            currentToken = MoveNextOrThrow(tokens, currentToken);
            if (!currentToken.Scopes.Contains("punctuation.separator.variable.prompt"))
            {
                throw PromptSyntaxError.Network_ExpectedSeparator(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Get model name
            currentToken = MoveNextOrThrow(tokens, currentToken);
            if (!currentToken.Scopes.Contains("meta.embedded.network.model.prompt"))
            {
                throw PromptSyntaxError.Network_ExpectedName(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            var modelName = GetTokenText(currentToken);

            // If index service provided, validate model name
            if (indexService != null)
            {
                // TryGetValue (not GetOrAdd) so validation never inserts entries into the index.
                var modelExists =
                    indexService.ModelIndex.TryGetValue(
                        parsedNetworkType.ConvertTo<SharedFolderType>(),
                        out var localModelList
                    )
                    && localModelList.Any(
                        m => Path.GetFileNameWithoutExtension(m.FileName) == modelName
                    );

                if (!modelExists)
                {
                    throw PromptValidationError.Network_UnknownModel(
                        modelName,
                        parsedNetworkType,
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }
            }

            // Next is either the closing ">" or another separator followed by a weight
            currentToken = MoveNextOrThrow(tokens, currentToken);

            double? weight = null;

            if (!currentToken.Scopes.Contains("punctuation.definition.network.end.prompt"))
            {
                // Ensure next token is colon
                if (!currentToken.Scopes.Contains("punctuation.separator.variable.prompt"))
                {
                    throw PromptSyntaxError.Network_ExpectedSeparator(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }

                // Get model weight
                currentToken = MoveNextOrThrow(tokens, currentToken);
                if (!currentToken.Scopes.Contains("constant.numeric"))
                {
                    throw PromptSyntaxError.Network_ExpectedWeight(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }

                // Invariant culture: prompt syntax always uses '.' as the decimal separator,
                // regardless of the user's locale.
                if (
                    !double.TryParse(
                        GetTokenText(currentToken),
                        NumberStyles.Float,
                        CultureInfo.InvariantCulture,
                        out var weightValue
                    )
                )
                {
                    throw PromptValidationError.Network_InvalidWeight(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }

                weight = weightValue;

                // Expect end
                currentToken = MoveNextOrThrow(tokens, currentToken);
                if (!currentToken.Scopes.Contains("punctuation.definition.network.end.prompt"))
                {
                    // NOTE(review): this reports "unexpected end of text" even though text remains;
                    // a dedicated "expected end" error may be clearer if one exists.
                    throw PromptSyntaxError.UnexpectedEndOfText(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }
            }

            // For embeddings, we add to the prompt, and not to the extra networks list
            if (parsedNetworkType is PromptExtraNetworkType.Embedding)
            {
                // Push to output in Comfy format
                // -> embedding:model
                // -> (embedding:model:weight)
                // Weight is formatted invariantly so the emitted prompt is locale-independent.
                outputTokens.Push(currentToken);
                outputText.Push(
                    weight is null
                        ? $"embedding:{modelName}"
                        : FormattableString.Invariant($"(embedding:{modelName}:{weight:F2})")
                );
            }
            // Cleanups for separate extra networks
            else
            {
                // If last entry on stack is a separator, remove it
                if (
                    outputTokens.TryPeek(out var lastToken)
                    && lastToken.Scopes.Contains("punctuation.separator.variable.prompt")
                )
                {
                    outputTokens.Pop();
                    outputText.Pop();
                }

                // Add to output
                promptExtraNetworks.Add(
                    new PromptExtraNetwork
                    {
                        Type = parsedNetworkType,
                        Name = modelName,
                        ModelWeight = weight
                    }
                );
            }
        }

        // Stacks enumerate newest-first, so reverse to restore original order.
        var processedText = string.Join("", outputText.Reverse());

        return (promptExtraNetworks, processedText);
    }

    /// <summary>
    /// Renders each token's text, index range, and scopes (with the ".prompt" suffix and the
    /// "source.prompt" root scope omitted) for debugging the tokenizer output.
    /// </summary>
    public string GetDebugText()
    {
        var builder = new StringBuilder();

        foreach (var token in TokenizeResult.Tokens)
        {
            // Clamp both ends to RawText. Token indices may run past the end of the text
            // (hence GetSafeEndIndex). Clamping keeps the final character and avoids an
            // inverted range for empty text.
            var start = Math.Min(token.StartIndex, RawText.Length);
            var text = RawText[start..GetSafeEndIndex(token.EndIndex)];

            // Format scope
            var scopeStr = string.Join(
                ", ",
                token.Scopes
                    .Where(s => s != "source.prompt")
                    .Select(
                        s =>
                            s.EndsWith(".prompt")
                                ? s.Remove(s.LastIndexOf(".prompt", StringComparison.Ordinal))
                                : s
                    )
            );

            builder.AppendLine($"{text.ToRepr()} ({token.StartIndex}, {token.EndIndex})");
            builder.AppendLine($"    └─ {scopeStr}");
        }

        return builder.ToString();
    }

    /// <summary>
    /// Tokenizes <paramref name="text"/> with <paramref name="tokenizer"/> and wraps the
    /// result in a new, unprocessed <see cref="Prompt"/>.
    /// </summary>
    public static Prompt FromRawText(string text, ITokenizerProvider tokenizer)
    {
        using var _ = new CodeTimer();

        var result = tokenizer.TokenizeLine(text);

        return new Prompt { RawText = text, TokenizeResult = result };
    }
}