using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AvaloniaEdit;
using AvaloniaEdit.Document;
using AvaloniaEdit.Editing;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Services;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// View model for the prompt card: holds the positive and negative prompt documents,
/// tokenizes/validates them, and exposes editor and help/debug commands.
/// </summary>
[View(typeof(PromptCard))]
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
    private readonly IModelIndexService modelIndexService;

    /// <summary>
    /// Cache of prompt text to tokenized <see cref="Prompt"/>, constructed with capacity 4.
    /// </summary>
    private LRUCache<string, Prompt> PromptCache { get; } = new(4);

    public ICompletionProvider CompletionProvider { get; }
    public ITokenizerProvider TokenizerProvider { get; }
    public SharedState SharedState { get; }

    /// <summary>Editor document backing the positive prompt.</summary>
    public TextDocument PromptDocument { get; } = new();

    /// <summary>Editor document backing the negative prompt.</summary>
    public TextDocument NegativePromptDocument { get; } = new();

    // Source-generated as IsAutoCompletionEnabled; kept in sync with settings in the constructor.
    [ObservableProperty]
    private bool isAutoCompletionEnabled;

    /// <summary>
    /// Creates the view model and relays <see cref="IsAutoCompletionEnabled"/>
    /// to the <c>IsPromptCompletionEnabled</c> setting.
    /// </summary>
    public PromptCardViewModel(
        ICompletionProvider completionProvider,
        ITokenizerProvider tokenizerProvider,
        ISettingsManager settingsManager,
        IModelIndexService modelIndexService,
        SharedState sharedState
    )
    {
        this.modelIndexService = modelIndexService;
        CompletionProvider = completionProvider;
        TokenizerProvider = tokenizerProvider;
        SharedState = sharedState;

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.IsAutoCompletionEnabled,
            settings => settings.IsPromptCompletionEnabled,
            true
        );
    }

    /// <summary>
    /// Gets the tokenized <see cref="Prompt"/> for the given text, tokenizing and
    /// caching it on a cache miss.
    /// </summary>
    /// <param name="text">Raw prompt text.</param>
    /// <returns>The cached or newly created prompt.</returns>
    private Prompt GetOrCachePrompt(string text)
    {
        // Try get from cache
        if (PromptCache.Get(text, out var cachedPrompt))
        {
            return cachedPrompt!;
        }

        var prompt = Prompt.FromRawText(text, TokenizerProvider);
        PromptCache.Add(text, prompt);
        return prompt;
    }

    /// <summary>
    /// Processes current positive prompt text into a <see cref="Prompt"/> object.
    /// </summary>
    public Prompt GetPrompt() => GetOrCachePrompt(PromptDocument.Text);

    /// <summary>
    /// Processes current negative prompt text into a <see cref="Prompt"/> object.
    /// </summary>
    public Prompt GetNegativePrompt() => GetOrCachePrompt(NegativePromptDocument.Text);

    /// <summary>
    /// Validates both prompts, showing an error dialog for the first invalid one.
    /// </summary>
    /// <returns>
    /// <c>true</c> if both prompts are valid; <c>false</c> after an error dialog was shown.
    /// </returns>
    public async Task<bool> ValidatePrompts()
    {
        // Snapshot both texts before any dialog is awaited, so the negative prompt that is
        // validated is the one present when validation started.
        var promptText = PromptDocument.Text;
        var negPromptText = NegativePromptDocument.Text;

        // && short-circuits: the negative prompt is only checked if the positive one passes.
        // NOTE(review): extra networks are only validated for the positive prompt — confirm intended.
        return await ValidatePromptTextAsync(promptText, validateExtraNetworks: true)
            && await ValidatePromptTextAsync(negPromptText, validateExtraNetworks: false);
    }

    /// <summary>
    /// Processes a single prompt text and, if requested, validates its extra networks.
    /// On <see cref="PromptError"/> shows the prompt error dialog.
    /// </summary>
    /// <param name="text">Raw prompt text to validate.</param>
    /// <param name="validateExtraNetworks">Whether to run extra-network validation.</param>
    /// <returns><c>true</c> if valid; <c>false</c> if an error dialog was shown.</returns>
    private async Task<bool> ValidatePromptTextAsync(string text, bool validateExtraNetworks)
    {
        try
        {
            var prompt = GetOrCachePrompt(text);
            prompt.Process();

            if (validateExtraNetworks)
            {
                prompt.ValidateExtraNetworks(modelIndexService);
            }

            return true;
        }
        catch (PromptError e)
        {
            var dialog = DialogHelper.CreatePromptErrorDialog(e, text, modelIndexService);
            await dialog.ShowAsync();
            return false;
        }
    }

    /// <summary>
    /// Shows a markdown dialog describing the prompt syntax.
    /// </summary>
    [RelayCommand]
    private async Task ShowHelpDialog()
    {
        var md = $"""
            ## {Resources.Label_Emphasis}
            ```prompt
            (keyword)
            (keyword:1.0)
            ```

            ## {Resources.Label_Deemphasis}
            ```prompt
            [keyword]
            ```

            ## {Resources.Label_EmbeddingsOrTextualInversion}
            They may be used in either the positive or negative prompts. 
            Essentially they are text presets, so the position where you place them 
            could make a difference. 
            ```prompt
            <embedding:model>
            <embedding:model:1.0>
            ```

            ## {Resources.Label_NetworksLoraOrLycoris}
            Unlike embeddings, network tags do not get tokenized to the model, 
            so the position in the prompt where you place them does not matter.
            ```prompt
            <lora:model>
            <lora:model:1.0>
            <lyco:model>
            <lyco:model:1.0>
            ```

            ## {Resources.Label_Comments}
            Inline comments can be marked by a hashtag ‘#’. 
            All text after a ‘#’ on a line will be disregarded during generation.
            ```prompt
            # comments
            a red cat # also comments
            detailed
            ```
            """;

        var dialog = DialogHelper.CreateMarkdownDialog(md, "Prompt Syntax", TextEditorPreset.Prompt);
        dialog.MinDialogWidth = 800;
        dialog.MaxDialogHeight = 1000;
        dialog.MaxDialogWidth = 1000;
        await dialog.ShowAsync();
    }

    /// <summary>
    /// Debug command: shows the tokens, extra networks and server-formatted text of the
    /// current positive prompt in a markdown dialog.
    /// </summary>
    [RelayCommand]
    private async Task DebugShowTokens()
    {
        var prompt = GetPrompt();

        try
        {
            prompt.Process();
        }
        catch (PromptError e)
        {
            await DialogHelper
                .CreatePromptErrorDialog(e, prompt.RawText, modelIndexService)
                .ShowAsync();
            return;
        }

        var tokens = prompt.TokenizeResult.Tokens;

        var builder = new StringBuilder();

        builder.AppendLine($"## Tokens ({tokens.Length}):");
        builder.AppendLine("```csharp");
        builder.AppendLine(prompt.GetDebugText());
        builder.AppendLine("```");

        try
        {
            if (prompt.ExtraNetworks is { } networks)
            {
                builder.AppendLine($"## Networks ({networks.Count}):");
                builder.AppendLine("```csharp");
                builder.AppendLine(
                    JsonSerializer.Serialize(
                        networks,
                        new JsonSerializerOptions { WriteIndented = true }
                    )
                );
                builder.AppendLine("```");
            }

            builder.AppendLine("## Formatted for server:");
            builder.AppendLine("```csharp");
            builder.AppendLine(prompt.ProcessedText);
            builder.AppendLine("```");
        }
        catch (PromptError e)
        {
            // Render the error into the dialog instead of rethrowing: a rethrow here would
            // discard the builder, so this error section would never be displayed.
            builder.AppendLine($"## {e.GetType().Name} - {e.Message}");
            builder.AppendLine("```csharp");
            builder.AppendLine(e.StackTrace);
            builder.AppendLine("```");
        }

        var dialog = DialogHelper.CreateMarkdownDialog(builder.ToString(), "Prompt Tokens");
        dialog.MinDialogWidth = 800;
        dialog.MaxDialogHeight = 1000;
        dialog.MaxDialogWidth = 1000;
        await dialog.ShowAsync();
    }

    /// <summary>Copies the selection of the given editor, if any.</summary>
    [RelayCommand]
    private void EditorCopy(TextEditor? textEditor)
    {
        textEditor?.Copy();
    }

    /// <summary>Pastes into the given editor, if any.</summary>
    [RelayCommand]
    private void EditorPaste(TextEditor? textEditor)
    {
        textEditor?.Paste();
    }

    /// <summary>Cuts the selection of the given editor, if any.</summary>
    [RelayCommand]
    private void EditorCut(TextEditor? textEditor)
    {
        textEditor?.Cut();
    }

    /// <inheritdoc />
    public override JsonObject SaveStateToJsonObject()
    {
        return SerializeModel(
            new PromptCardModel
            {
                Prompt = PromptDocument.Text,
                NegativePrompt = NegativePromptDocument.Text
            }
        );
    }

    /// <inheritdoc />
    public override void LoadStateFromJsonObject(JsonObject state)
    {
        var model = DeserializeModel<PromptCardModel>(state);

        // Missing prompts are loaded as empty text
        PromptDocument.Text = model.Prompt ?? "";
        NegativePromptDocument.Text = model.NegativePrompt ?? "";
    }

    /// <inheritdoc />
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        PromptDocument.Text = parameters.PositivePrompt ?? "";
        NegativePromptDocument.Text = parameters.NegativePrompt ?? "";
    }

    /// <inheritdoc />
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        return parameters with
        {
            PositivePrompt = PromptDocument.Text,
            NegativePrompt = NegativePromptDocument.Text
        };
    }
}