using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AvaloniaEdit;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// View model for the prompt card. Holds the positive and negative prompt documents,
/// contributes the CLIP text-encode nodes of a generation in <see cref="ApplyStep"/>,
/// and saves / loads its state as JSON or as <see cref="GenerationParameters"/>.
/// </summary>
[View(typeof(PromptCard))]
[ManagedService]
[Transient]
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState, IComfyStep
{
    /// <summary>
    /// Serializer options for the token debug dialog; cached so that a new options
    /// instance is not allocated on every invocation.
    /// </summary>
    private static readonly JsonSerializerOptions DebugJsonSerializerOptions =
        new() { WriteIndented = true };

    private readonly IModelIndexService modelIndexService;

    /// <summary>
    /// Cache of prompt text to tokenized Prompt
    /// </summary>
    private LRUCache<string, Prompt> PromptCache { get; } = new(4);

    public ICompletionProvider CompletionProvider { get; }
    public ITokenizerProvider TokenizerProvider { get; }
    public SharedState SharedState { get; }

    /// <summary>Document backing the positive prompt editor.</summary>
    public TextDocument PromptDocument { get; } = new();

    /// <summary>Document backing the negative prompt editor.</summary>
    public TextDocument NegativePromptDocument { get; } = new();

    /// <summary>Card hosting the optional "Styles" modules that are applied after the prompts are loaded.</summary>
    public StackEditableCardViewModel ModulesCardViewModel { get; }

    [ObservableProperty]
    private bool isAutoCompletionEnabled;

    /// <summary>
    /// Creates the view model, builds the "Styles" module card and relays
    /// <see cref="IsAutoCompletionEnabled"/> to the prompt-completion setting.
    /// </summary>
    // NOTE(review): the type of `vmFactory` appears to have lost a generic type argument;
    // it is left exactly as found here — confirm against the ServiceManager declaration.
    public PromptCardViewModel(
        ICompletionProvider completionProvider,
        ITokenizerProvider tokenizerProvider,
        ISettingsManager settingsManager,
        IModelIndexService modelIndexService,
        ServiceManager vmFactory,
        SharedState sharedState
    )
    {
        this.modelIndexService = modelIndexService;
        CompletionProvider = completionProvider;
        TokenizerProvider = tokenizerProvider;
        SharedState = sharedState;

        ModulesCardViewModel = vmFactory.Get<StackEditableCardViewModel>(vm =>
        {
            vm.Title = "Styles";
            vm.AvailableModules = [typeof(PromptExpansionModule)];
        });

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.IsAutoCompletionEnabled,
            settings => settings.IsPromptCompletionEnabled,
            true
        );
    }

    /// <summary>
    /// Applies the prompt step.
    /// Requires:
    /// <list type="bullet">
    /// <item>Model, Clip</item>
    /// </list>
    /// Provides:
    /// <list type="bullet">
    /// <item>Conditioning</item>
    /// </list>
    /// </summary>
    /// <param name="e">Event args carrying the builder and node dictionary that are being populated.</param>
    public void ApplyStep(ModuleApplyStepEventArgs e)
    {
        // Load prompts
        var positivePrompt = GetPrompt();
        positivePrompt.Process();
        e.Builder.Connections.PositivePrompt = positivePrompt.ProcessedText;

        var negativePrompt = GetNegativePrompt();
        negativePrompt.Process();
        e.Builder.Connections.NegativePrompt = negativePrompt.ProcessedText;

        // Apply modules / styles that may modify the prompt
        ModulesCardViewModel.ApplyStep(e);

        foreach (var modelConnections in e.Builder.Connections.Models.Values)
        {
            // Connections without both a model and a clip cannot be conditioned
            if (modelConnections.Model is not { } model || modelConnections.Clip is not { } clip)
            {
                continue;
            }

            // If need to load loras, add a group
            if (positivePrompt.ExtraNetworks.Count > 0)
            {
                var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList();

                // Add group to load loras onto model and clip in series
                var lorasGroup = e.Builder.Group_LoraLoadMany(
                    $"Loras_{modelConnections.Name}",
                    model,
                    clip,
                    loras
                );

                // Set last outputs as model and clip
                modelConnections.Model = lorasGroup.Output1;
                modelConnections.Clip = lorasGroup.Output2;
            }

            // Encode with the clip of THIS connection (the LoRA-patched one when LoRAs were
            // loaded above) rather than always the base connection's clip, so every model
            // connection is conditioned by its own text encoder.
            var encoderClip = modelConnections.Clip ?? clip;

            // Clips
            var positiveClip = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.CLIPTextEncode
                {
                    Name = $"PositiveCLIP_{modelConnections.Name}",
                    Clip = encoderClip,
                    Text = e.Builder.Connections.PositivePrompt
                }
            );
            var negativeClip = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.CLIPTextEncode
                {
                    Name = $"NegativeCLIP_{modelConnections.Name}",
                    Clip = encoderClip,
                    Text = e.Builder.Connections.NegativePrompt
                }
            );

            // Set conditioning from Clips
            modelConnections.Conditioning = (positiveClip.Output, negativeClip.Output);
        }
    }

    /// <summary>
    /// Gets the tokenized Prompt for given text and caches it
    /// </summary>
    /// <param name="text">Raw prompt text, also used as the cache key.</param>
    private Prompt GetOrCachePrompt(string text)
    {
        // Try get from cache
        if (PromptCache.Get(text, out var cachedPrompt))
        {
            return cachedPrompt!;
        }

        var prompt = Prompt.FromRawText(text, TokenizerProvider);
        PromptCache.Add(text, prompt);
        return prompt;
    }

    /// <summary>
    /// Processes current positive prompt text into a Prompt object
    /// </summary>
    public Prompt GetPrompt() => GetOrCachePrompt(PromptDocument.Text);

    /// <summary>
    /// Processes current negative prompt text into a Prompt object
    /// </summary>
    public Prompt GetNegativePrompt() => GetOrCachePrompt(NegativePromptDocument.Text);

    /// <summary>
    /// Validates both prompts, shows an error dialog if invalid
    /// </summary>
    /// <returns>
    /// <c>true</c> when both prompts were processed without a <see cref="PromptError"/>;
    /// <c>false</c> after the error dialog for the first failing prompt has been shown.
    /// </returns>
    public async Task<bool> ValidatePrompts()
    {
        var promptText = PromptDocument.Text;
        var negPromptText = NegativePromptDocument.Text;

        try
        {
            var prompt = GetOrCachePrompt(promptText);
            prompt.Process();
            // Only the positive prompt has its extra networks validated
            prompt.ValidateExtraNetworks(modelIndexService);
        }
        catch (PromptError e)
        {
            await ShowPromptErrorDialogAsync(e, promptText);
            return false;
        }

        try
        {
            var negPrompt = GetOrCachePrompt(negPromptText);
            negPrompt.Process();
        }
        catch (PromptError e)
        {
            await ShowPromptErrorDialogAsync(e, negPromptText);
            return false;
        }

        return true;
    }

    /// <summary>
    /// Shows the prompt error dialog for <paramref name="error"/> against the given prompt text.
    /// </summary>
    private async Task ShowPromptErrorDialogAsync(PromptError error, string promptText)
    {
        var dialog = DialogHelper.CreatePromptErrorDialog(error, promptText, modelIndexService);
        await dialog.ShowAsync();
    }

    /// <summary>
    /// Shows a markdown dialog describing the prompt syntax.
    /// </summary>
    [RelayCommand]
    private async Task ShowHelpDialog()
    {
        // NOTE(review): the example blocks under the embeddings and networks headings are
        // empty; the example tags look like they were lost at some point — restore them.
        var md = $"""
            ## {Resources.Label_Emphasis}
            ```prompt
            (keyword)
            (keyword:1.0)
            ```

            ## {Resources.Label_Deemphasis}
            ```prompt
            [keyword]
            ```

            ## {Resources.Label_EmbeddingsOrTextualInversion}
            They may be used in either the positive or negative prompts.
            Essentially they are text presets, so the position where you place them
            could make a difference.
            ```prompt
            ```

            ## {Resources.Label_NetworksLoraOrLycoris}
            Unlike embeddings, network tags do not get tokenized to the model,
            so the position in the prompt where you place them does not matter.
            ```prompt
            ```

            ## {Resources.Label_Comments}
            Inline comments can be marked by a hashtag ‘#’.
            All text after a ‘#’ on a line will be disregarded during generation.
            ```prompt
            # comments
            a red cat # also comments
            detailed
            ```
            """;

        var dialog = DialogHelper.CreateMarkdownDialog(md, "Prompt Syntax", TextEditorPreset.Prompt);
        dialog.MinDialogWidth = 800;
        dialog.MaxDialogHeight = 1000;
        dialog.MaxDialogWidth = 1000;
        await dialog.ShowAsync();
    }

    /// <summary>
    /// Shows a markdown dialog with the tokens, extra networks and processed text
    /// of the current positive prompt.
    /// </summary>
    [RelayCommand]
    private async Task DebugShowTokens()
    {
        var prompt = GetPrompt();

        try
        {
            prompt.Process();
        }
        catch (PromptError e)
        {
            await ShowPromptErrorDialogAsync(e, prompt.RawText);
            return;
        }

        var tokens = prompt.TokenizeResult.Tokens;

        var builder = new StringBuilder();

        builder.AppendLine($"## Tokens ({tokens.Length}):");
        builder.AppendLine("```csharp");
        builder.AppendLine(prompt.GetDebugText());
        builder.AppendLine("```");

        try
        {
            if (prompt.ExtraNetworks is { } networks)
            {
                builder.AppendLine($"## Networks ({networks.Count}):");
                builder.AppendLine("```csharp");
                builder.AppendLine(JsonSerializer.Serialize(networks, DebugJsonSerializerOptions));
                builder.AppendLine("```");
            }

            builder.AppendLine("## Formatted for server:");
            builder.AppendLine("```csharp");
            builder.AppendLine(prompt.ProcessedText);
            builder.AppendLine("```");
        }
        catch (PromptError e)
        {
            // Render the error inside the dialog. Rethrowing here would discard the
            // text appended below before the dialog is ever shown.
            builder.AppendLine($"## {e.GetType().Name} - {e.Message}");
            builder.AppendLine("```csharp");
            builder.AppendLine(e.StackTrace);
            builder.AppendLine("```");
        }

        var dialog = DialogHelper.CreateMarkdownDialog(builder.ToString(), "Prompt Tokens");
        dialog.MinDialogWidth = 800;
        dialog.MaxDialogHeight = 1000;
        dialog.MaxDialogWidth = 1000;
        await dialog.ShowAsync();
    }

    /// <summary>Copies from the given editor; does nothing when it is <c>null</c>.</summary>
    [RelayCommand]
    private void EditorCopy(TextEditor? textEditor)
    {
        textEditor?.Copy();
    }

    /// <summary>Pastes into the given editor; does nothing when it is <c>null</c>.</summary>
    [RelayCommand]
    private void EditorPaste(TextEditor? textEditor)
    {
        textEditor?.Paste();
    }

    /// <summary>Cuts from the given editor; does nothing when it is <c>null</c>.</summary>
    [RelayCommand]
    private void EditorCut(TextEditor? textEditor)
    {
        textEditor?.Cut();
    }

    /// <inheritdoc />
    public override JsonObject SaveStateToJsonObject()
    {
        return SerializeModel(
            new PromptCardModel
            {
                Prompt = PromptDocument.Text,
                NegativePrompt = NegativePromptDocument.Text,
                ModulesCardState = ModulesCardViewModel.SaveStateToJsonObject()
            }
        );
    }

    /// <inheritdoc />
    public override void LoadStateFromJsonObject(JsonObject state)
    {
        var model = DeserializeModel<PromptCardModel>(state);

        // Missing prompts load as empty text
        PromptDocument.Text = model.Prompt ?? "";
        NegativePromptDocument.Text = model.NegativePrompt ?? "";

        if (model.ModulesCardState is not null)
        {
            ModulesCardViewModel.LoadStateFromJsonObject(model.ModulesCardState);
        }
    }

    /// <inheritdoc />
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        PromptDocument.Text = parameters.PositivePrompt ?? "";
        NegativePromptDocument.Text = parameters.NegativePrompt ?? "";
    }

    /// <inheritdoc />
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        return parameters with
        {
            PositivePrompt = PromptDocument.Text,
            NegativePrompt = NegativePromptDocument.Text
        };
    }
}