using System;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AvaloniaEdit;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.ViewModels;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))]
[ManagedService]
[Transient]
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState, IComfyStep
{
    /// <summary>
    /// Serializer options used by the debug token dialog. Cached in a static field because
    /// <see cref="JsonSerializerOptions"/> instances are costly to create and are meant to be reused.
    /// </summary>
    private static readonly JsonSerializerOptions DebugJsonSerializerOptions = new() { WriteIndented = true };

    private readonly IModelIndexService modelIndexService;

    /// <summary>
    /// Small LRU cache mapping raw prompt text to its tokenized <see cref="Prompt"/>,
    /// so repeated reads of unchanged editor text skip re-tokenization.
    /// </summary>
    private LRUCache<string, Prompt> PromptCache { get; } = new(4);

    /// <summary>Completion provider used by the prompt editors.</summary>
    public ICompletionProvider CompletionProvider { get; }

    /// <summary>Tokenizer used to turn raw prompt text into a <see cref="Prompt"/>.</summary>
    public ITokenizerProvider TokenizerProvider { get; }

    /// <summary>Application-wide shared state.</summary>
    public SharedState SharedState { get; }

    /// <summary>Editor document holding the positive prompt text.</summary>
    public TextDocument PromptDocument { get; } = new();

    /// <summary>Editor document holding the negative prompt text.</summary>
    public TextDocument NegativePromptDocument { get; } = new();

    /// <summary>Nested "Styles" stack card; only <see cref="PromptExpansionModule"/> may be added to it.</summary>
    public StackEditableCardViewModel StackEditableCardViewModel { get; }

    /// <summary>Whether editor auto-completion is enabled; kept in sync with the prompt-completion setting.</summary>
    [ObservableProperty]
    private bool isAutoCompletionEnabled;

    /// <summary>
    /// Creates the prompt card, builds the nested "Styles" stack card and relays
    /// <see cref="IsAutoCompletionEnabled"/> to the <c>IsPromptCompletionEnabled</c> setting.
    /// </summary>
    public PromptCardViewModel(
        ICompletionProvider completionProvider,
        ITokenizerProvider tokenizerProvider,
        ISettingsManager settingsManager,
        IModelIndexService modelIndexService,
        ServiceManager<ViewModelBase> vmFactory,
        SharedState sharedState
    )
    {
        this.modelIndexService = modelIndexService;
        CompletionProvider = completionProvider;
        TokenizerProvider = tokenizerProvider;
        SharedState = sharedState;

        StackEditableCardViewModel = vmFactory.Get<StackEditableCardViewModel>(vm =>
        {
            vm.Title = "Styles";
            vm.AvailableModules = [typeof(PromptExpansionModule)];
        });

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.IsAutoCompletionEnabled,
            settings => settings.IsPromptCompletionEnabled,
            true
        );
    }

    /// <summary>
    /// Applies the prompt step to every model connection in the workflow builder.
    /// </summary>
    /// <remarks>
    /// Requires, per model connection: <c>Model</c> and <c>Clip</c>. Connections missing either are skipped.
    /// <para/>
    /// Provides, per model connection: <c>Conditioning</c> (positive, negative).
    /// When the positive prompt contains extra networks (e.g. LoRAs), they are loaded in series
    /// onto the connection's model and clip, and those outputs replace <c>Model</c> and <c>Clip</c>.
    /// </remarks>
    /// <param name="e">Step arguments carrying the node builder and node collection.</param>
    public void ApplyStep(ModuleApplyStepEventArgs e)
    {
        // Load prompts
        var positivePrompt = GetPrompt();
        positivePrompt.Process();

        var negativePrompt = GetNegativePrompt();
        negativePrompt.Process();

        foreach (var modelConnections in e.Builder.Connections.Models.Values)
        {
            if (modelConnections.Model is not { } model || modelConnections.Clip is not { } clip)
            {
                continue;
            }

            // If need to load loras, add a group
            if (positivePrompt.ExtraNetworks.Count > 0)
            {
                var loras = positivePrompt.GetExtraNetworksAsLocalModels(modelIndexService).ToList();

                // Add group to load loras onto model and clip in series
                var lorasGroup = e.Builder.Group_LoraLoadMany(
                    $"Loras_{modelConnections.Name}",
                    model,
                    clip,
                    loras
                );

                // Set last outputs as model and clip
                modelConnections.Model = lorasGroup.Output1;
                modelConnections.Clip = lorasGroup.Output2;
                clip = lorasGroup.Output2;
            }

            // Encode with THIS model's own (possibly LoRA-patched) clip. Using the base model's
            // clip here would give non-base models (e.g. a refiner) conditioning from the wrong
            // text encoder, or an un-patched one if they are visited before the base model.
            var positiveClip = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.CLIPTextEncode
                {
                    Name = $"PositiveCLIP_{modelConnections.Name}",
                    Clip = clip,
                    Text = positivePrompt.ProcessedText
                }
            );

            var negativeClip = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.CLIPTextEncode
                {
                    Name = $"NegativeCLIP_{modelConnections.Name}",
                    Clip = clip,
                    Text = negativePrompt.ProcessedText
                }
            );

            // Set conditioning from Clips
            modelConnections.Conditioning = (positiveClip.Output, negativeClip.Output);
        }
    }

    /// <summary>
    /// Gets the tokenized <see cref="Prompt"/> for the given text, reusing a cached instance when
    /// the same text was tokenized recently and caching newly created ones.
    /// </summary>
    /// <param name="text">Raw prompt text.</param>
    /// <returns>The (possibly cached) tokenized prompt.</returns>
    private Prompt GetOrCachePrompt(string text)
    {
        // Try get from cache
        if (PromptCache.Get(text, out var cachedPrompt))
        {
            return cachedPrompt!;
        }

        var prompt = Prompt.FromRawText(text, TokenizerProvider);
        PromptCache.Add(text, prompt);
        return prompt;
    }

    /// <summary>
    /// Gets the tokenized <see cref="Prompt"/> for the current positive prompt text.
    /// </summary>
    public Prompt GetPrompt() => GetOrCachePrompt(PromptDocument.Text);

    /// <summary>
    /// Gets the tokenized <see cref="Prompt"/> for the current negative prompt text.
    /// </summary>
    public Prompt GetNegativePrompt() => GetOrCachePrompt(NegativePromptDocument.Text);

    /// <summary>
    /// Validates both prompts. On the first <see cref="PromptError"/> encountered, shows an error
    /// dialog for the offending prompt text.
    /// </summary>
    /// <remarks>
    /// The positive prompt is also checked for missing extra networks; the negative prompt is only processed.
    /// </remarks>
    /// <returns><c>true</c> if both prompts are valid; otherwise <c>false</c>.</returns>
    public async Task<bool> ValidatePrompts()
    {
        var promptText = PromptDocument.Text;
        var negPromptText = NegativePromptDocument.Text;

        try
        {
            var prompt = GetOrCachePrompt(promptText);
            prompt.Process();
            prompt.ValidateExtraNetworks(modelIndexService);
        }
        catch (PromptError e)
        {
            var dialog = DialogHelper.CreatePromptErrorDialog(e, promptText, modelIndexService);
            await dialog.ShowAsync();
            return false;
        }

        try
        {
            var negPrompt = GetOrCachePrompt(negPromptText);
            negPrompt.Process();
        }
        catch (PromptError e)
        {
            var dialog = DialogHelper.CreatePromptErrorDialog(e, negPromptText, modelIndexService);
            await dialog.ShowAsync();
            return false;
        }

        return true;
    }

    /// <summary>
    /// Shows a markdown dialog describing the supported prompt syntax.
    /// </summary>
    [RelayCommand]
    private async Task ShowHelpDialog()
    {
        var md = $"""
            ## {Resources.Label_Emphasis}
            ```prompt
            (keyword)
            (keyword:1.0)
            ```
            ## {Resources.Label_Deemphasis}
            ```prompt
            [keyword]
            ```
            ## {Resources.Label_EmbeddingsOrTextualInversion}
            They may be used in either the positive or negative prompts.
            Essentially they are text presets, so the position where you place them
            could make a difference.
            ```prompt
            <embedding:model>
            ```
            ## {Resources.Label_NetworksLoraOrLycoris}
            Unlike embeddings, network tags do not get tokenized to the model,
            so the position in the prompt where you place them does not matter.
            ```prompt
            <lora:model>
            <lora:model:0.8>
            <lyco:model>
            <lyco:model:0.8>
            ```
            ## {Resources.Label_Comments}
            Inline comments can be marked by a hashtag ‘#’.
            All text after a ‘#’ on a line will be disregarded during generation.
            ```prompt
            # comments
            a red cat # also comments
            detailed
            ```
            """;

        var dialog = DialogHelper.CreateMarkdownDialog(md, "Prompt Syntax", TextEditorPreset.Prompt);
        dialog.MinDialogWidth = 800;
        dialog.MaxDialogHeight = 1000;
        dialog.MaxDialogWidth = 1000;
        await dialog.ShowAsync();
    }

    /// <summary>
    /// Debug command: shows the tokens, extra networks and processed text of the positive prompt.
    /// If processing fails, shows the prompt error dialog instead.
    /// </summary>
    [RelayCommand]
    private async Task DebugShowTokens()
    {
        var prompt = GetPrompt();

        try
        {
            prompt.Process();
        }
        catch (PromptError e)
        {
            await DialogHelper.CreatePromptErrorDialog(e, prompt.RawText, modelIndexService).ShowAsync();
            return;
        }

        var tokens = prompt.TokenizeResult.Tokens;

        var builder = new StringBuilder();

        builder.AppendLine($"## Tokens ({tokens.Length}):");
        builder.AppendLine("```csharp");
        builder.AppendLine(prompt.GetDebugText());
        builder.AppendLine("```");

        try
        {
            if (prompt.ExtraNetworks is { } networks)
            {
                builder.AppendLine($"## Networks ({networks.Count}):");
                builder.AppendLine("```csharp");
                builder.AppendLine(JsonSerializer.Serialize(networks, DebugJsonSerializerOptions));
                builder.AppendLine("```");
            }

            builder.AppendLine("## Formatted for server:");
            builder.AppendLine("```csharp");
            builder.AppendLine(prompt.ProcessedText);
            builder.AppendLine("```");
        }
        catch (PromptError e)
        {
            // Report the error inside the debug dialog. The original rethrew here, which discarded
            // this text and the dialog. "## " needs the space to render as a markdown heading.
            builder.AppendLine($"## {e.GetType().Name} - {e.Message}");
            builder.AppendLine("```csharp");
            builder.AppendLine(e.StackTrace);
            builder.AppendLine("```");
        }

        var dialog = DialogHelper.CreateMarkdownDialog(builder.ToString(), "Prompt Tokens");
        dialog.MinDialogWidth = 800;
        dialog.MaxDialogHeight = 1000;
        dialog.MaxDialogWidth = 1000;
        await dialog.ShowAsync();
    }

    /// <summary>Copies the editor's current selection, if an editor is given.</summary>
    [RelayCommand]
    private void EditorCopy(TextEditor? textEditor)
    {
        textEditor?.Copy();
    }

    /// <summary>Pastes into the editor, if an editor is given.</summary>
    [RelayCommand]
    private void EditorPaste(TextEditor? textEditor)
    {
        textEditor?.Paste();
    }

    /// <summary>Cuts the editor's current selection, if an editor is given.</summary>
    [RelayCommand]
    private void EditorCut(TextEditor? textEditor)
    {
        textEditor?.Cut();
    }

    /// <inheritdoc />
    public override JsonObject SaveStateToJsonObject()
    {
        return SerializeModel(
            new PromptCardModel { Prompt = PromptDocument.Text, NegativePrompt = NegativePromptDocument.Text }
        );
    }

    /// <inheritdoc />
    public override void LoadStateFromJsonObject(JsonObject state)
    {
        var model = DeserializeModel<PromptCardModel>(state);

        // Missing prompts are loaded as empty text
        PromptDocument.Text = model.Prompt ?? "";
        NegativePromptDocument.Text = model.NegativePrompt ?? "";
    }

    /// <inheritdoc />
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        PromptDocument.Text = parameters.PositivePrompt ?? "";
        NegativePromptDocument.Text = parameters.NegativePrompt ?? "";
    }

    /// <inheritdoc />
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        return parameters with
        {
            PositivePrompt = PromptDocument.Text,
            NegativePrompt = NegativePromptDocument.Text
        };
    }
}