You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
287 lines
8.6 KiB
287 lines
8.6 KiB
using System; |
|
using System.Collections.Generic; |
|
using System.Collections.Immutable; |
|
using System.Diagnostics; |
|
using System.Linq; |
|
using System.Text; |
|
using System.Text.Json; |
|
using System.Text.Json.Nodes; |
|
using System.Threading.Tasks; |
|
using AvaloniaEdit; |
|
using AvaloniaEdit.Document; |
|
using AvaloniaEdit.Editing; |
|
using CommunityToolkit.Mvvm.ComponentModel; |
|
using CommunityToolkit.Mvvm.Input; |
|
using StabilityMatrix.Avalonia.Controls; |
|
using StabilityMatrix.Avalonia.Languages; |
|
using StabilityMatrix.Avalonia.Models; |
|
using StabilityMatrix.Avalonia.Models.Inference; |
|
using StabilityMatrix.Avalonia.Models.TagCompletion; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Exceptions; |
|
using StabilityMatrix.Core.Extensions; |
|
using StabilityMatrix.Core.Helper.Cache; |
|
using StabilityMatrix.Core.Services; |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Inference; |
|
|
|
[View(typeof(PromptCard))] |
|
public partial class PromptCardViewModel : LoadableViewModelBase |
|
{ |
|
private readonly IModelIndexService modelIndexService; |
|
|
|
/// <summary> |
|
/// Cache of prompt text to tokenized Prompt |
|
/// </summary> |
|
private LRUCache<string, Prompt> PromptCache { get; } = new(4); |
|
|
|
public ICompletionProvider CompletionProvider { get; } |
|
public ITokenizerProvider TokenizerProvider { get; } |
|
public SharedState SharedState { get; } |
|
|
|
public TextDocument PromptDocument { get; } = new(); |
|
public TextDocument NegativePromptDocument { get; } = new(); |
|
|
|
[ObservableProperty] |
|
private bool isAutoCompletionEnabled; |
|
|
|
/// <inheritdoc /> |
|
public PromptCardViewModel( |
|
ICompletionProvider completionProvider, |
|
ITokenizerProvider tokenizerProvider, |
|
ISettingsManager settingsManager, |
|
IModelIndexService modelIndexService, |
|
SharedState sharedState |
|
) |
|
{ |
|
this.modelIndexService = modelIndexService; |
|
CompletionProvider = completionProvider; |
|
TokenizerProvider = tokenizerProvider; |
|
SharedState = sharedState; |
|
|
|
settingsManager.RelayPropertyFor( |
|
this, |
|
vm => vm.IsAutoCompletionEnabled, |
|
settings => settings.IsPromptCompletionEnabled, |
|
true |
|
); |
|
} |
|
|
|
/// <summary> |
|
/// Gets the tokenized Prompt for given text and caches it |
|
/// </summary> |
|
private Prompt GetOrCachePrompt(string text) |
|
{ |
|
// Try get from cache |
|
if (PromptCache.Get(text, out var cachedPrompt)) |
|
{ |
|
return cachedPrompt!; |
|
} |
|
var prompt = Prompt.FromRawText(text, TokenizerProvider); |
|
PromptCache.Add(text, prompt); |
|
return prompt; |
|
} |
|
|
|
/// <summary> |
|
/// Processes current positive prompt text into a Prompt object |
|
/// </summary> |
|
public Prompt GetPrompt() => GetOrCachePrompt(PromptDocument.Text); |
|
|
|
/// <summary> |
|
/// Processes current negative prompt text into a Prompt object |
|
/// </summary> |
|
public Prompt GetNegativePrompt() => GetOrCachePrompt(NegativePromptDocument.Text); |
|
|
|
/// <summary> |
|
/// Validates both prompts, shows an error dialog if invalid |
|
/// </summary> |
|
public async Task<bool> ValidatePrompts() |
|
{ |
|
var promptText = PromptDocument.Text; |
|
var negPromptText = NegativePromptDocument.Text; |
|
|
|
try |
|
{ |
|
var prompt = GetOrCachePrompt(promptText); |
|
prompt.Process(); |
|
prompt.ValidateExtraNetworks(modelIndexService); |
|
} |
|
catch (PromptError e) |
|
{ |
|
var dialog = DialogHelper.CreatePromptErrorDialog(e, promptText, modelIndexService); |
|
await dialog.ShowAsync(); |
|
return false; |
|
} |
|
|
|
try |
|
{ |
|
var negPrompt = GetOrCachePrompt(negPromptText); |
|
negPrompt.Process(); |
|
} |
|
catch (PromptError e) |
|
{ |
|
var dialog = DialogHelper.CreatePromptErrorDialog(e, negPromptText, modelIndexService); |
|
await dialog.ShowAsync(); |
|
return false; |
|
} |
|
|
|
return true; |
|
} |
|
|
|
[RelayCommand] |
|
private async Task ShowHelpDialog() |
|
{ |
|
var md = $""" |
|
## {Resources.Label_Emphasis} |
|
```prompt |
|
(keyword) |
|
(keyword:1.0) |
|
``` |
|
|
|
## {Resources.Label_Deemphasis} |
|
```prompt |
|
[keyword] |
|
``` |
|
|
|
## {Resources.Label_EmbeddingsOrTextualInversion} |
|
They may be used in either the positive or negative prompts. |
|
Essentially they are text presets, so the position where you place them |
|
could make a difference. |
|
```prompt |
|
<embedding:model> |
|
<embedding:model:1.0> |
|
``` |
|
|
|
## {Resources.Label_NetworksLoraOrLycoris} |
|
Unlike embeddings, network tags do not get tokenized to the model, |
|
so the position in the prompt where you place them does not matter. |
|
```prompt |
|
<lora:model> |
|
<lora:model:1.0> |
|
<lyco:model> |
|
<lyco:model:1.0> |
|
``` |
|
|
|
## {Resources.Label_Comments} |
|
Inline comments can be marked by a hashtag ‘#’. |
|
All text after a ‘#’ on a line will be disregarded during generation. |
|
```prompt |
|
# comments |
|
a red cat # also comments |
|
detailed |
|
``` |
|
"""; |
|
|
|
var dialog = DialogHelper.CreateMarkdownDialog( |
|
md, |
|
"Prompt Syntax", |
|
TextEditorPreset.Prompt |
|
); |
|
dialog.MinDialogWidth = 800; |
|
dialog.MaxDialogHeight = 1000; |
|
dialog.MaxDialogWidth = 1000; |
|
await dialog.ShowAsync(); |
|
} |
|
|
|
[RelayCommand] |
|
private async Task DebugShowTokens() |
|
{ |
|
var prompt = GetPrompt(); |
|
|
|
try |
|
{ |
|
prompt.Process(); |
|
} |
|
catch (PromptError e) |
|
{ |
|
await DialogHelper |
|
.CreatePromptErrorDialog(e, prompt.RawText, modelIndexService) |
|
.ShowAsync(); |
|
return; |
|
} |
|
|
|
var tokens = prompt.TokenizeResult.Tokens; |
|
|
|
var builder = new StringBuilder(); |
|
|
|
builder.AppendLine($"## Tokens ({tokens.Length}):"); |
|
builder.AppendLine("```csharp"); |
|
builder.AppendLine(prompt.GetDebugText()); |
|
builder.AppendLine("```"); |
|
|
|
try |
|
{ |
|
if (prompt.ExtraNetworks is { } networks) |
|
{ |
|
builder.AppendLine($"## Networks ({networks.Count}):"); |
|
builder.AppendLine("```csharp"); |
|
builder.AppendLine( |
|
JsonSerializer.Serialize( |
|
networks, |
|
new JsonSerializerOptions() { WriteIndented = true, } |
|
) |
|
); |
|
builder.AppendLine("```"); |
|
} |
|
|
|
builder.AppendLine("## Formatted for server:"); |
|
builder.AppendLine("```csharp"); |
|
builder.AppendLine(prompt.ProcessedText); |
|
builder.AppendLine("```"); |
|
} |
|
catch (PromptError e) |
|
{ |
|
builder.AppendLine($"##{e.GetType().Name} - {e.Message}"); |
|
builder.AppendLine("```csharp"); |
|
builder.AppendLine(e.StackTrace); |
|
builder.AppendLine("```"); |
|
throw; |
|
} |
|
|
|
var dialog = DialogHelper.CreateMarkdownDialog(builder.ToString(), "Prompt Tokens"); |
|
dialog.MinDialogWidth = 800; |
|
dialog.MaxDialogHeight = 1000; |
|
dialog.MaxDialogWidth = 1000; |
|
await dialog.ShowAsync(); |
|
} |
|
|
|
[RelayCommand] |
|
private void EditorCopy(TextEditor? textEditor) |
|
{ |
|
textEditor?.Copy(); |
|
} |
|
|
|
[RelayCommand] |
|
private void EditorPaste(TextEditor? textEditor) |
|
{ |
|
textEditor?.Paste(); |
|
} |
|
|
|
[RelayCommand] |
|
private void EditorCut(TextEditor? textEditor) |
|
{ |
|
textEditor?.Cut(); |
|
} |
|
|
|
/// <inheritdoc /> |
|
public override JsonObject SaveStateToJsonObject() |
|
{ |
|
return SerializeModel( |
|
new PromptCardModel |
|
{ |
|
Prompt = PromptDocument.Text, |
|
NegativePrompt = NegativePromptDocument.Text |
|
} |
|
); |
|
} |
|
|
|
/// <inheritdoc /> |
|
public override void LoadStateFromJsonObject(JsonObject state) |
|
{ |
|
var model = DeserializeModel<PromptCardModel>(state); |
|
|
|
PromptDocument.Text = model.Prompt ?? ""; |
|
NegativePromptDocument.Text = model.NegativePrompt ?? ""; |
|
} |
|
}
|
|
|