using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.Tokens;
using StabilityMatrix.Core.Services;
using TextMateSharp.Grammars;
namespace StabilityMatrix.Avalonia.Models.Inference;
/// <summary>
/// A prompt's raw text together with its tokenization, plus the extra networks and
/// backend-ready text extracted from it by <see cref="Process"/>.
/// </summary>
public record Prompt
{
    /// <summary>
    /// The prompt text exactly as passed to <see cref="FromRawText"/>.
    /// </summary>
    public required string RawText { get; init; }

    /// <summary>
    /// Tokenization of <see cref="RawText"/>; token start/end indices refer to positions in
    /// <see cref="RawText"/>.
    /// </summary>
    public required ITokenizeLineResult TokenizeResult { get; init; }

    /// <summary>
    /// Whether <see cref="Process"/> has completed, in which case <see cref="ExtraNetworks"/>
    /// and <see cref="ProcessedText"/> are non-null.
    /// </summary>
    /// <remarks>
    /// NOTE(review): a <c>with</c> copy of a processed prompt carries this flag and the cached
    /// results along, even if <see cref="RawText"/> is replaced in the copy.
    /// </remarks>
    [MemberNotNullWhen(true, nameof(ExtraNetworks), nameof(ProcessedText))]
    public bool IsProcessed { get; private set; }

    /// <summary>
    /// Extra networks (LoRA / LyCORIS) specified in the prompt. Embeddings are not listed here;
    /// they stay inline in <see cref="ProcessedText"/>. Null until <see cref="Process"/> runs.
    /// </summary>
    public IReadOnlyList<PromptExtraNetwork>? ExtraNetworks { get; private set; }

    /// <summary>
    /// Processed text suitable for sending to the inference backend.
    /// Extra network tokens (e.g. LoRA) and comments are removed, and embeddings are rewritten
    /// to <c>embedding:name</c> form. Null until <see cref="Process"/> runs.
    /// </summary>
    public string? ProcessedText { get; private set; }

    /// <summary>
    /// Parses the prompt once, populating <see cref="ExtraNetworks"/> and
    /// <see cref="ProcessedText"/>. Subsequent calls return immediately.
    /// </summary>
    /// <exception cref="PromptSyntaxError">Thrown if the prompt contains invalid syntax.</exception>
    /// <exception cref="PromptValidationError">
    /// Thrown if an extra network has an unknown type or an unparsable weight.
    /// </exception>
    [MemberNotNull(nameof(ExtraNetworks), nameof(ProcessedText))]
    public void Process()
    {
        if (IsProcessed)
            return;

        var (promptExtraNetworks, processedText) = GetExtraNetworks();
        ExtraNetworks = promptExtraNetworks;
        ProcessedText = processedText;

        // Without this the guard above never fires: every call re-parsed, and the
        // MemberNotNullWhen contract on IsProcessed could not be relied upon.
        IsProcessed = true;
    }

    /// <summary>
    /// Verifies that the prompt parses and that every extra network it references
    /// (including embeddings) exists in the local model index.
    /// </summary>
    /// <param name="indexService">Model index used to look up referenced model files.</param>
    /// <exception cref="PromptValidationError">Thrown if a referenced model file does not exist.</exception>
    /// <exception cref="PromptSyntaxError">Thrown if the prompt contains invalid syntax.</exception>
    public void ValidateExtraNetworks(IModelIndexService indexService)
    {
        // Parsing with an index service performs the existence checks; the results are unused.
        _ = GetExtraNetworks(indexService);
    }

    /// <summary>
    /// Resolves <see cref="ExtraNetworks"/> to local model files and their weights.
    /// </summary>
    /// <param name="indexService">Model index used to look up each network's file.</param>
    /// <returns>One tuple per extra network, in prompt order.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown if <see cref="Process"/> has not been called.
    /// </exception>
    /// <exception cref="ApplicationException">
    /// Thrown if a network's model is not present in the index.
    /// </exception>
    /// <remarks>
    /// This is an iterator, so both exceptions surface when the result is enumerated,
    /// not when the method is called.
    /// </remarks>
    public IEnumerable<(
        LocalModelFile ModelFile,
        double? ModelWeight,
        double? ClipWeight
    )> GetExtraNetworksAsLocalModels(IModelIndexService indexService)
    {
        if (ExtraNetworks is null)
        {
            throw new InvalidOperationException(
                "Prompt must be processed before calling GetExtraNetworksAsLocalModels"
            );
        }

        foreach (var network in ExtraNetworks)
        {
            var sharedFolderType = network.Type.ConvertTo<SharedFolderType>();

            if (!indexService.ModelIndex.TryGetValue(sharedFolderType, out var modelList))
            {
                throw new ApplicationException($"Model {network.Name} does not exist in index");
            }

            var localModel = modelList.FirstOrDefault(
                m => m.FileNameWithoutExtension == network.Name
            );

            if (localModel == null)
            {
                throw new ApplicationException($"Model {network.Name} does not exist in index");
            }

            yield return (localModel, network.ModelWeight, network.ClipWeight);
        }
    }

    /// <summary>
    /// Clamps a token end index to the length of <see cref="RawText"/>, guarding against
    /// token ranges that run past the end of the text.
    /// </summary>
    private int GetSafeEndIndex(int index)
    {
        return Math.Min(index, RawText.Length);
    }

    /// <summary>
    /// Returns the slice of <see cref="RawText"/> covered by <paramref name="token"/>.
    /// </summary>
    private string GetTokenText(IToken token)
    {
        return RawText[token.StartIndex..GetSafeEndIndex(token.EndIndex)];
    }

    /// <summary>
    /// Advances <paramref name="tokens"/> and returns the new current token.
    /// </summary>
    /// <param name="tokens">Token enumerator being parsed.</param>
    /// <param name="previous">Last consumed token; used to position the error if none remain.</param>
    /// <exception cref="PromptSyntaxError">Thrown (unexpected end of text) if no tokens remain.</exception>
    private IToken MoveNextOrThrow(IEnumerator<IToken> tokens, IToken previous)
    {
        if (!tokens.MoveNext())
        {
            throw PromptSyntaxError.UnexpectedEndOfText(
                previous.StartIndex,
                GetSafeEndIndex(previous.EndIndex)
            );
        }

        return tokens.Current;
    }

    /// <summary>
    /// Parses the token stream, separating extra networks from the rest of the prompt.
    /// </summary>
    /// <param name="indexService">
    /// When non-null, every referenced model (including embeddings) must exist in its index.
    /// </param>
    /// <returns>
    /// The LoRA / LyCORIS networks found, and the remaining text with comments removed and
    /// embeddings rewritten to Comfy <c>embedding:</c> syntax.
    /// </returns>
    /// <exception cref="PromptSyntaxError">Thrown on invalid or incomplete network syntax.</exception>
    /// <exception cref="PromptValidationError">
    /// Thrown on an unknown network type, an unparsable weight, or (with an index service) a
    /// missing model.
    /// </exception>
    private (List<PromptExtraNetwork> promptExtraNetworks, string processedText) GetExtraNetworks(
        IModelIndexService? indexService = null
    )
    {
        // Network syntax "<type:model>" or "<type:model:weight>" is tokenized as
        // "meta.structure.network.prompt" containing:
        // "<": "punctuation.definition.network.begin.prompt"
        // (type): "meta.embedded.network.type.prompt"
        // ":": "punctuation.separator.variable.prompt"
        // (content): "meta.embedded.network.model.prompt"
        // ">": "punctuation.definition.network.end.prompt"
        using var tokens = TokenizeResult.Tokens.Cast<IToken>().GetEnumerator();

        // Non-network tokens and their text. Stacks, so a separator directly preceding a
        // removed network can be popped off again.
        var outputTokens = new Stack<IToken>();
        var outputText = new Stack<string>();

        // LoRA / LyCORIS networks (embeddings stay in the text instead)
        var promptExtraNetworks = new List<PromptExtraNetwork>();

        while (tokens.MoveNext())
        {
            var currentToken = tokens.Current;

            // For any invalid syntax, throw
            if (currentToken.Scopes.Any(s => s.Contains("invalid.illegal")))
            {
                throw new PromptSyntaxError(
                    "Invalid Token",
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Comments - ignore
            if (currentToken.Scopes.Any(s => s.Contains("comment.line")))
            {
                continue;
            }

            // Anything that does not start a network passes through to the output unchanged
            if (!currentToken.Scopes.Contains("punctuation.definition.network.begin.prompt"))
            {
                outputTokens.Push(currentToken);
                outputText.Push(GetTokenText(currentToken));
                continue;
            }

            // Network type
            currentToken = MoveNextOrThrow(tokens, currentToken);
            if (!currentToken.Scopes.Contains("meta.embedded.network.type.prompt"))
            {
                throw PromptSyntaxError.Network_ExpectedType(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            var parsedNetworkType = GetTokenText(currentToken) switch
            {
                "lora" => PromptExtraNetworkType.Lora,
                "lyco" => PromptExtraNetworkType.LyCORIS,
                "embedding" => PromptExtraNetworkType.Embedding,
                _
                    => throw PromptValidationError.Network_UnknownType(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    )
            };

            // Separator between type and model name
            currentToken = MoveNextOrThrow(tokens, currentToken);
            if (!currentToken.Scopes.Contains("punctuation.separator.variable.prompt"))
            {
                throw PromptSyntaxError.Network_ExpectedSeparator(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Model name
            currentToken = MoveNextOrThrow(tokens, currentToken);
            if (!currentToken.Scopes.Contains("meta.embedded.network.model.prompt"))
            {
                throw PromptSyntaxError.Network_ExpectedName(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            var modelName = GetTokenText(currentToken);

            // If an index service was provided, the model file must exist locally.
            // TryGetValue (rather than GetOrAdd) keeps validation from inserting empty
            // entries into the shared model index.
            if (indexService != null)
            {
                var modelExists =
                    indexService.ModelIndex.TryGetValue(
                        parsedNetworkType.ConvertTo<SharedFolderType>(),
                        out var localModelList
                    )
                    && localModelList.Any(
                        m => Path.GetFileNameWithoutExtension(m.FileName) == modelName
                    );

                if (!modelExists)
                {
                    throw PromptValidationError.Network_UnknownModel(
                        modelName,
                        parsedNetworkType,
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }
            }

            // Either the closing ">" or a separator followed by a weight
            currentToken = MoveNextOrThrow(tokens, currentToken);

            double? weight = null;

            if (!currentToken.Scopes.Contains("punctuation.definition.network.end.prompt"))
            {
                if (!currentToken.Scopes.Contains("punctuation.separator.variable.prompt"))
                {
                    throw PromptSyntaxError.Network_ExpectedSeparator(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }

                // Model weight
                currentToken = MoveNextOrThrow(tokens, currentToken);
                if (!currentToken.Scopes.Contains("constant.numeric"))
                {
                    throw PromptSyntaxError.Network_ExpectedWeight(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }

                // Parse with the invariant culture so '.' is always the decimal separator.
                // The default overload uses the current culture: under de-DE '.' is a group
                // separator, so "0.8" would be read as 8.
                if (
                    !double.TryParse(
                        GetTokenText(currentToken),
                        NumberStyles.Float | NumberStyles.AllowThousands,
                        CultureInfo.InvariantCulture,
                        out var weightValue
                    )
                )
                {
                    throw PromptValidationError.Network_InvalidWeight(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }

                weight = weightValue;

                // Expect the closing ">"
                currentToken = MoveNextOrThrow(tokens, currentToken);
                if (!currentToken.Scopes.Contains("punctuation.definition.network.end.prompt"))
                {
                    // NOTE(review): reported as "unexpected end of text"; no dedicated
                    // "expected end" error factory is visible here.
                    throw PromptSyntaxError.UnexpectedEndOfText(
                        currentToken.StartIndex,
                        GetSafeEndIndex(currentToken.EndIndex)
                    );
                }
            }

            if (parsedNetworkType is PromptExtraNetworkType.Embedding)
            {
                // Embeddings stay in the prompt, rewritten to Comfy format:
                // -> embedding:model
                // -> (embedding:model:weight)
                // The weight is formatted invariantly so a ',' decimal locale cannot emit
                // "(embedding:model:0,80)".
                outputTokens.Push(currentToken);
                outputText.Push(
                    weight is { } embeddingWeight
                        ? $"(embedding:{modelName}:{embeddingWeight.ToString("F2", CultureInfo.InvariantCulture)})"
                        : $"embedding:{modelName}"
                );
            }
            else
            {
                // Separate extra networks are removed from the text; if the last output
                // entry is a separator, drop it as well.
                if (
                    outputTokens.TryPeek(out var lastToken)
                    && lastToken.Scopes.Contains("punctuation.separator.variable.prompt")
                )
                {
                    outputTokens.Pop();
                    outputText.Pop();
                }

                promptExtraNetworks.Add(
                    new PromptExtraNetwork
                    {
                        Type = parsedNetworkType,
                        Name = modelName,
                        ModelWeight = weight
                    }
                );
            }
        }

        // A Stack enumerates newest-first; reverse to restore prompt order.
        var processedText = string.Join("", outputText.Reverse());
        return (promptExtraNetworks, processedText);
    }

    /// <summary>
    /// Builds a human-readable dump of every token. Each token gets its text (via
    /// <c>ToRepr</c>), its start/end indices, and its scopes, excluding the root
    /// "source.prompt" scope and with ".prompt" suffixes trimmed.
    /// </summary>
    public string GetDebugText()
    {
        var builder = new StringBuilder();

        foreach (var token in TokenizeResult.Tokens)
        {
            // Clamp with the same helper the parser uses. The previous `Length - 1` bound
            // dropped the final character and threw when RawText was empty.
            var text = RawText[token.StartIndex..GetSafeEndIndex(token.EndIndex)];

            // Format scope
            var scopeStr = string.Join(
                ", ",
                token.Scopes
                    .Where(s => s != "source.prompt")
                    .Select(
                        s =>
                            s.EndsWith(".prompt")
                                ? s.Remove(s.LastIndexOf(".prompt", StringComparison.Ordinal))
                                : s
                    )
            );

            builder.AppendLine($"{text.ToRepr()} ({token.StartIndex}, {token.EndIndex})");
            builder.AppendLine($" └─ {scopeStr}");
        }

        return builder.ToString();
    }

    /// <summary>
    /// Tokenizes <paramref name="text"/> and wraps it in an unprocessed <see cref="Prompt"/>.
    /// </summary>
    /// <param name="text">Raw prompt text.</param>
    /// <param name="tokenizer">Tokenizer used to produce <see cref="TokenizeResult"/>.</param>
    public static Prompt FromRawText(string text, ITokenizerProvider tokenizer)
    {
        // NOTE(review): CodeTimer presumably measures how long tokenization takes.
        using var _ = new CodeTimer();
        var result = tokenizer.TokenizeLine(text);
        return new Prompt { RawText = text, TokenizeResult = result };
    }
}