|
|
|
using System;
|
|
|
|
using System.Collections.Generic;
|
|
|
|
using System.ComponentModel.DataAnnotations;
|
|
|
|
using System.Diagnostics.CodeAnalysis;
|
|
|
|
using System.IO;
|
|
|
|
using System.Linq;
|
|
|
|
using System.Text;
|
|
|
|
using StabilityMatrix.Avalonia.Models.TagCompletion;
|
|
|
|
using StabilityMatrix.Core.Exceptions;
|
|
|
|
using StabilityMatrix.Core.Extensions;
|
|
|
|
using StabilityMatrix.Core.Helper;
|
|
|
|
using StabilityMatrix.Core.Models;
|
|
|
|
using StabilityMatrix.Core.Models.Database;
|
|
|
|
using StabilityMatrix.Core.Models.Tokens;
|
|
|
|
using StabilityMatrix.Core.Services;
|
|
|
|
using TextMateSharp.Grammars;
|
|
|
|
|
|
|
|
namespace StabilityMatrix.Avalonia.Models.Inference;
|
|
|
|
|
|
|
|
public record Prompt
|
|
|
|
{
|
|
|
|
/// <summary>
/// Unmodified prompt text. Token start/end indices are used to slice this string.
/// </summary>
public required string RawText { get; init; }
|
|
|
|
|
|
|
|
/// <summary>
/// Tokenization result for the prompt; its tokens are what the parser iterates over.
/// </summary>
public required ITokenizeLineResult TokenizeResult { get; init; }
|
|
|
|
|
|
|
|
/// <summary>
/// Flag checked by <see cref="Process"/> to skip repeated work. Per the attribute below,
/// a value of true promises that <see cref="ExtraNetworks"/> and
/// <see cref="ProcessedText"/> are non-null.
/// </summary>
[MemberNotNullWhen(true, nameof(ExtraNetworks), nameof(ProcessedText))]
|
|
|
|
public bool IsProcessed { get; private set; }
|
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Extra networks specified in prompt. Null until <see cref="Process"/> has assigned it.
|
|
|
|
/// </summary>
|
|
|
|
public IReadOnlyList<PromptExtraNetwork>? ExtraNetworks { get; private set; }
|
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Processed text suitable for sending to inference backend.
|
|
|
|
/// This excludes extra network (e.g. LoRA) tokens. Null until <see cref="Process"/> has assigned it.
|
|
|
|
/// </summary>
|
|
|
|
public string? ProcessedText { get; private set; }
|
|
|
|
|
|
|
|
/// <summary>
/// Parses the tokenized prompt once and stores the result in
/// <see cref="ExtraNetworks"/> and <see cref="ProcessedText"/>.
/// Calls made after a successful run return immediately.
/// </summary>
/// <exception cref="PromptSyntaxError">Thrown if the prompt contains invalid or incomplete syntax.</exception>
/// <exception cref="PromptValidationError">Thrown if a network type is unknown or a weight cannot be parsed.</exception>
[MemberNotNull(nameof(ExtraNetworks), nameof(ProcessedText))]
public void Process()
{
    if (IsProcessed)
    {
        return;
    }

    var (promptExtraNetworks, processedText) = GetExtraNetworks();
    ExtraNetworks = promptExtraNetworks;
    ProcessedText = processedText;

    // Set the flag last: if parsing throws, the prompt stays unprocessed and a later
    // call will retry instead of exposing half-initialized state.
    // NOTE(review): record `with`-copies inherit this flag together with the cached
    // results; build prompts for changed text via FromRawText rather than `with`.
    IsProcessed = true;
}
|
|
|
|
|
|
|
|
/// <summary>
/// Verifies that the extra network files referenced in the prompt exist locally.
/// The prompt is re-parsed with model lookup enabled and the parse result is discarded.
/// </summary>
/// <param name="indexService">Model index used to look up each referenced model by file name.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="indexService"/> is null.</exception>
/// <exception cref="PromptSyntaxError">Thrown if the prompt contains invalid or incomplete syntax.</exception>
/// <exception cref="PromptValidationError">
/// Thrown if a filename does not exist, a network type is unknown, or a weight cannot be parsed.
/// </exception>
public void ValidateExtraNetworks(IModelIndexService indexService)
{
    // The private parser treats a null index as "skip model validation", so a null
    // argument here would silently validate nothing; reject it explicitly.
    ArgumentNullException.ThrowIfNull(indexService);

    _ = GetExtraNetworks(indexService);
}
|
|
|
|
|
|
|
|
/// <summary>
/// Get ExtraNetworks as local model files and weights.
/// </summary>
/// <remarks>
/// The "must be processed" precondition is checked immediately when this method is called.
/// Model lookups remain lazy and run as the returned sequence is enumerated.
/// </remarks>
/// <param name="indexService">Model index used to resolve each network name to a local file.</param>
/// <returns>One tuple per entry of <see cref="ExtraNetworks"/>, in the same order.</returns>
/// <exception cref="InvalidOperationException">Thrown if the prompt has not been processed yet.</exception>
/// <exception cref="ApplicationException">
/// Thrown during enumeration if a network's model cannot be found in the index.
/// </exception>
public IEnumerable<(
    LocalModelFile ModelFile,
    double? ModelWeight,
    double? ClipWeight
)> GetExtraNetworksAsLocalModels(IModelIndexService indexService)
{
    // Checked outside the iterator so the failure surfaces at the call site
    // instead of on the first MoveNext().
    if (ExtraNetworks is null)
    {
        throw new InvalidOperationException(
            "Prompt must be processed before calling GetExtraNetworksAsLocalModels"
        );
    }

    return ResolveLocalModels(ExtraNetworks);

    // Lazily maps each parsed network to the matching file in the model index.
    IEnumerable<(LocalModelFile ModelFile, double? ModelWeight, double? ClipWeight)> ResolveLocalModels(
        IReadOnlyList<PromptExtraNetwork> networks
    )
    {
        foreach (var network in networks)
        {
            var sharedFolderType = network.Type.ConvertTo<SharedFolderType>();

            // No entry at all for this folder type
            if (!indexService.ModelIndex.TryGetValue(sharedFolderType, out var modelList))
            {
                throw new ApplicationException($"Model {network.Name} does not exist in index");
            }

            // Models are matched by file name without extension
            var localModel = modelList.FirstOrDefault(
                m => m.FileNameWithoutExtension == network.Name
            );
            if (localModel == null)
            {
                throw new ApplicationException($"Model {network.Name} does not exist in index");
            }

            yield return (localModel, network.ModelWeight, network.ClipWeight);
        }
    }
}
|
|
|
|
|
|
|
|
/// <summary>
/// Clamps an end index so that it never exceeds the length of <see cref="RawText"/>.
/// </summary>
/// <param name="index">End index to clamp, e.g. a token's EndIndex.</param>
/// <returns><paramref name="index"/>, or the length of <see cref="RawText"/> if that is smaller.</returns>
private int GetSafeEndIndex(int index) => index > RawText.Length ? RawText.Length : index;
|
|
|
|
|
|
|
|
/// <summary>
/// Walks the tokenized prompt and splits it into extra network references and plain prompt text.
/// </summary>
/// <remarks>
/// Network tags have the form <c>&lt;type:name&gt;</c> or <c>&lt;type:name:weight&gt;</c> and are
/// recognized through these token scopes ("meta.structure.network.prompt"):
/// <list type="bullet">
/// <item>"&lt;": "punctuation.definition.network.begin.prompt"</item>
/// <item>(type): "meta.embedded.network.type.prompt"</item>
/// <item>":": "punctuation.separator.variable.prompt"</item>
/// <item>(content): "meta.embedded.network.model.prompt"</item>
/// <item>"&gt;": "punctuation.definition.network.end.prompt"</item>
/// </list>
/// Weights are parsed and formatted with the invariant culture so that results do not
/// depend on the user's regional settings.
/// </remarks>
/// <param name="indexService">
/// Optional model index. When provided, every network's model name must match a file
/// (name without extension) in the index; when null, model names are not checked.
/// </param>
/// <returns>
/// The lora / lyco networks found, and the prompt text with comments and those network
/// tags removed. Embedding tags stay in the text, rewritten as <c>embedding:name</c>
/// or <c>(embedding:name:weight)</c>.
/// </returns>
/// <exception cref="PromptSyntaxError">Thrown on illegal tokens or malformed / unterminated network tags.</exception>
/// <exception cref="PromptValidationError">
/// Thrown on an unknown network type, an unparsable weight, or (with an index) an unknown model.
/// </exception>
private (List<PromptExtraNetwork> promptExtraNetworks, string processedText) GetExtraNetworks(
    IModelIndexService? indexService = null
)
{
    const string NetworkBeginScope = "punctuation.definition.network.begin.prompt";
    const string NetworkEndScope = "punctuation.definition.network.end.prompt";
    const string NetworkTypeScope = "meta.embedded.network.type.prompt";
    const string NetworkModelScope = "meta.embedded.network.model.prompt";
    const string SeparatorScope = "punctuation.separator.variable.prompt";

    using var tokens = TokenizeResult.Tokens.Cast<IToken>().GetEnumerator();

    // Non-network output, kept as parallel stacks so the most recent entry can be
    // inspected and removed again.
    var outputTokens = new Stack<IToken>();
    var outputText = new Stack<string>();

    // Extra networks found so far
    var promptExtraNetworks = new List<PromptExtraNetwork>();

    while (tokens.MoveNext())
    {
        var currentToken = tokens.Current;

        // For any invalid syntax, throw (generic error)
        if (currentToken.Scopes.Any(s => s.Contains("invalid.illegal")))
        {
            throw new PromptSyntaxError(
                "Invalid Token",
                currentToken.StartIndex,
                GetSafeEndIndex(currentToken.EndIndex)
            );
        }

        // Comments are dropped from the output
        if (currentToken.Scopes.Any(s => s.Contains("comment.line")))
        {
            continue;
        }

        // Until a network tag starts, tokens are copied to the output as-is
        if (!currentToken.Scopes.Contains(NetworkBeginScope))
        {
            outputTokens.Push(currentToken);
            outputText.Push(TextOf(currentToken));
            continue;
        }

        // Network type
        currentToken = Advance(currentToken);
        if (!currentToken.Scopes.Contains(NetworkTypeScope))
        {
            throw PromptSyntaxError.Network_ExpectedType(
                currentToken.StartIndex,
                GetSafeEndIndex(currentToken.EndIndex)
            );
        }

        var parsedNetworkType = TextOf(currentToken) switch
        {
            "lora" => PromptExtraNetworkType.Lora,
            "lyco" => PromptExtraNetworkType.LyCORIS,
            "embedding" => PromptExtraNetworkType.Embedding,
            _
                => throw PromptValidationError.Network_UnknownType(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                )
        };

        // Separator between type and model name
        currentToken = Advance(currentToken);
        if (!currentToken.Scopes.Contains(SeparatorScope))
        {
            throw PromptSyntaxError.Network_ExpectedSeparator(
                currentToken.StartIndex,
                GetSafeEndIndex(currentToken.EndIndex)
            );
        }

        // Model name
        currentToken = Advance(currentToken);
        if (!currentToken.Scopes.Contains(NetworkModelScope))
        {
            throw PromptSyntaxError.Network_ExpectedName(
                currentToken.StartIndex,
                GetSafeEndIndex(currentToken.EndIndex)
            );
        }

        var modelName = TextOf(currentToken);

        // If index service provided, validate model name
        if (indexService != null)
        {
            var localModelList = indexService.ModelIndex.GetOrAdd(
                parsedNetworkType.ConvertTo<SharedFolderType>()
            );
            var localModel = localModelList.FirstOrDefault(
                m => Path.GetFileNameWithoutExtension(m.FileName) == modelName
            );
            if (localModel == null)
            {
                throw PromptValidationError.Network_UnknownModel(
                    modelName,
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }
        }

        // Either the tag ends here, or a separator and a weight follow
        currentToken = Advance(currentToken);

        double? weight = null;
        if (!currentToken.Scopes.Contains(NetworkEndScope))
        {
            if (!currentToken.Scopes.Contains(SeparatorScope))
            {
                throw PromptSyntaxError.Network_ExpectedSeparator(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Model weight
            currentToken = Advance(currentToken);
            if (!currentToken.Scopes.Contains("constant.numeric"))
            {
                throw PromptSyntaxError.Network_ExpectedWeight(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }

            // Prompt syntax always uses '.' as the decimal separator, so parse with the
            // invariant culture rather than the user's current culture.
            if (
                !double.TryParse(
                    TextOf(currentToken),
                    System.Globalization.NumberStyles.Float,
                    System.Globalization.CultureInfo.InvariantCulture,
                    out var weightValue
                )
            )
            {
                throw PromptValidationError.Network_InvalidWeight(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }
            weight = weightValue;

            // Expect end of tag
            currentToken = Advance(currentToken);
            if (!currentToken.Scopes.Contains(NetworkEndScope))
            {
                throw PromptSyntaxError.UnexpectedEndOfText(
                    currentToken.StartIndex,
                    GetSafeEndIndex(currentToken.EndIndex)
                );
            }
        }

        // For embeddings, we add to the prompt, and not to the extra networks list
        if (parsedNetworkType is PromptExtraNetworkType.Embedding)
        {
            // Push to output in Comfy format
            // <embedding:model> -> embedding:model
            // <embedding:model:weight> -> (embedding:model:weight)
            outputTokens.Push(currentToken);

            // Invariant formatting keeps '.' as the decimal separator in the emitted text
            var embeddingText = weight is { } embeddingWeight
                ? FormattableString.Invariant($"(embedding:{modelName}:{embeddingWeight:F2})")
                : $"embedding:{modelName}";
            outputText.Push(embeddingText);
        }
        // Cleanups for separate extra networks
        else
        {
            // If last entry on stack is a separator, remove it
            if (
                outputTokens.TryPeek(out var lastToken)
                && lastToken.Scopes.Contains(SeparatorScope)
            )
            {
                outputTokens.Pop();
                outputText.Pop();
            }

            promptExtraNetworks.Add(
                new PromptExtraNetwork
                {
                    Type = parsedNetworkType,
                    Name = modelName,
                    ModelWeight = weight
                }
            );
        }
    }

    // The stack enumerates newest-first; reverse to restore prompt order
    var processedText = string.Join("", outputText.Reverse());

    return (promptExtraNetworks, processedText);

    // Raw text covered by a token, with the end index clamped to the text length.
    string TextOf(IToken token) =>
        RawText[token.StartIndex..GetSafeEndIndex(token.EndIndex)];

    // Moves to the next token, or reports the end of text at the position of the previous token.
    IToken Advance(IToken previous)
    {
        if (!tokens.MoveNext())
        {
            throw PromptSyntaxError.UnexpectedEndOfText(
                previous.StartIndex,
                GetSafeEndIndex(previous.EndIndex)
            );
        }

        return tokens.Current;
    }
}
|
|
|
|
|
|
|
|
/// <summary>
/// Builds a human-readable dump of every token: its text, its start / end indices,
/// and its scopes (without "source.prompt" and with any trailing ".prompt" removed).
/// </summary>
/// <returns>Two lines per token: the token text with indices, then its scope list.</returns>
public string GetDebugText()
{
    const string ScopeSuffix = ".prompt";

    var builder = new StringBuilder();

    foreach (var token in TokenizeResult.Tokens)
    {
        // Clamp the same way the parser does. Clamping to Length - 1 would cut off the
        // last character of the text and produce a negative index for empty text.
        var endIndex = GetSafeEndIndex(token.EndIndex);
        var startIndex = Math.Min(token.StartIndex, endIndex);
        var text = RawText[startIndex..endIndex];

        // Format scope
        var scopeStr = string.Join(
            ", ",
            token.Scopes
                .Where(s => s != "source.prompt")
                .Select(
                    s =>
                        s.EndsWith(ScopeSuffix, StringComparison.Ordinal)
                            ? s[..^ScopeSuffix.Length]
                            : s
                )
        );

        builder.AppendLine($"{text.ToRepr()} ({token.StartIndex}, {token.EndIndex})");
        builder.AppendLine($" └─ {scopeStr}");
    }

    return builder.ToString();
}
|
|
|
|
|
|
|
|
/// <summary>
/// Creates an unprocessed <see cref="Prompt"/> by tokenizing the given text.
/// </summary>
/// <param name="text">Prompt text; stored as <see cref="RawText"/>.</param>
/// <param name="tokenizer">Tokenizer whose result is stored as <see cref="TokenizeResult"/>.</param>
/// <returns>A new prompt holding the text and its tokenization.</returns>
public static Prompt FromRawText(string text, ITokenizerProvider tokenizer)
{
    // The CodeTimer is disposed when this scope ends (presumably timing the call - TODO confirm)
    using (new CodeTimer())
    {
        var tokenized = tokenizer.TokenizeLine(text);
        return new Prompt { RawText = text, TokenizeResult = tokenized };
    }
}
|
|
|
|
}
|