Browse Source

Update completion behaviors for extra models syntax

pull/165/head
Ionite 1 year ago
parent
commit
de03999143
No known key found for this signature in database
  1. 64
      StabilityMatrix.Avalonia/Assets/ImagePrompt.tmLanguage.json
  2. 19
      StabilityMatrix.Avalonia/Assets/ThemeMatrixDark.json
  3. 109
      StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs
  4. 244
      StabilityMatrix.Avalonia/Behaviors/TextEditorToolTipBehavior.cs
  5. 13
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionData.cs
  6. 6
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionIcons.cs
  7. 26
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs
  8. 19
      StabilityMatrix.Avalonia/Controls/CodeCompletion/ICompletionData.cs
  9. 111
      StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarker.cs
  10. 262
      StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarkerService.cs
  11. 8
      StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarkerValidationEventArgs.cs
  12. 63
      StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarkerValidatorService.cs
  13. 216
      StabilityMatrix.Avalonia/Models/Inference/Prompt.cs
  14. 13
      StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetwork.cs
  15. 16
      StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetworkType.cs
  16. 6
      StabilityMatrix.Avalonia/Models/Inference/Tokens/TokenNames.cs
  17. 10
      StabilityMatrix.Avalonia/Models/TagCompletion/CompletionType.cs
  18. 8
      StabilityMatrix.Avalonia/Models/TagCompletion/EditorCompletionRequest.cs
  19. 35
      StabilityMatrix.Avalonia/Models/TagCompletion/ModelCompletionData.cs
  20. 12
      StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs
  21. 13
      StabilityMatrix.Core/Exceptions/PromptError.cs
  22. 24
      StabilityMatrix.Core/Exceptions/PromptSyntaxError.cs
  23. 15
      StabilityMatrix.Core/Exceptions/PromptValidationError.cs

64
StabilityMatrix.Avalonia/Assets/ImagePrompt.tmLanguage.json

@ -1,6 +1,9 @@
{
"name": "Image Prompt",
"scopeName": "source.prompt",
"uuid": "A5283894-BA62-4BFE-BB29-7892AE7ACCDC",
"foldingStartMarker": "^.*\b(\\#)\b.*$",
"foldingStopMarker": "(\r?\n){2}",
"patterns": [
{
"include": "#value"
@ -107,16 +110,55 @@
"name": "meta.structure.network.prompt",
"patterns": [
{
"include": "#colon"
"match": "(?<=\\<)([^,:\\<\\> ]+)(:)([^,:\\<\\> ]+)(:)(\\d+(?:\\.\\d+)?)",
"captures": {
"1": {
"name": "meta.embedded.network.type.prompt"
},
"2": {
"name": "punctuation.separator.variable.prompt"
},
"3": {
"name": "meta.embedded.network.model.prompt"
},
"4": {
"name": "punctuation.separator.variable.prompt"
},
"5" : {
"name": "constant.numeric"
}
}
},
{
"include": "#number"
"match": "(?<=\\<)([^,:\\<\\> ]+)(:)([^,:\\<\\> ]+)",
"captures": {
"1": {
"name": "meta.embedded.network.type.prompt"
},
"2": {
"name": "punctuation.separator.variable.prompt"
},
"3": {
"name": "meta.embedded.network.model.prompt"
}
}
},
{
"include": "#text"
"match": "(?<=\\<)([^,:\\<\\> ]+)",
"captures": {
"1": {
"name": "meta.embedded.network.type.prompt"
}
}
},
{
"include": "#colon"
},
{
"include": "#number"
},
{
"match": "[^\\s\\>]",
"match": "[^\\s\\>]+",
"name": "invalid.illegal.expected-array-separator.prompt"
}
]
@ -141,8 +183,12 @@
"match": "\\s+",
"name": "meta.embedded.whitespace"
},
"model": {
"match": "\\b(?<type>[\\w\\d_]+):(?<model>\\w+)(?::(?<weight>\\d+(\\.\\d+)?))?\\b",
"name": "meta.embedded.model"
},
"text": {
"match": "[^,:\\[\\]\\(\\) \\\\]+",
"match": "[^,:\\[\\]\\(\\)\\<\\> \\\\]+",
"name": "meta.embedded"
},
"value": {
@ -173,6 +219,14 @@
},
{
"include": "#text"
},
{
"name": "invalid.illegal.mismatched.parenthesis.closing.prompt",
"match": "\\)"
},
{
"name": "invalid.illegal.mismatched.parenthesis.opening.prompt",
"match": "\\("
}
]
}

19
StabilityMatrix.Avalonia/Assets/ThemeMatrixDark.json

@ -76,6 +76,25 @@
"foreground": "#C5C8C6"
}
},
{
"name": "Network Type",
"scope": [
"meta.embedded.network.type"
],
"settings": {
"fontStyle": "italic",
"foreground": "#3990F6"
}
},
{
"name": "Network Model",
"scope": [
"meta.embedded.network.model"
],
"settings": {
"foreground": "#D0B344"
}
},
{
"name": "Comment",
"scope": "comment",

109
StabilityMatrix.Avalonia/Behaviors/TextEditorCompletionBehavior.cs

@ -2,6 +2,7 @@ using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Avalonia;
using Avalonia.Controls;
using Avalonia.Input;
using Avalonia.Xaml.Interactivity;
using AvaloniaEdit;
@ -9,6 +10,7 @@ using AvaloniaEdit.Document;
using AvaloniaEdit.Editing;
using NLog;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Models.Inference.Tokens;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Extensions;
using TextMateSharp.Grammars;
@ -22,6 +24,10 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
private TextEditor textEditor = null!;
/// <summary>
/// The current completion window, if open.
/// Is set to null when the window is closed.
/// </summary>
private CompletionWindow? completionWindow;
public static readonly StyledProperty<ICompletionProvider?> CompletionProviderProperty =
@ -100,12 +106,14 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
if (completionWindow == null)
{
// Get the segment of the token the caret is currently in
if (GetCaretCompletionToken() is not { } tokenSegment)
if (GetCaretCompletionToken() is not { } completionRequest)
{
Logger.Trace("Token segment not found");
return;
}
var tokenSegment = completionRequest.Segment;
var token = textEditor.Document.GetText(tokenSegment);
Logger.Trace("Using token {Token} ({@Segment})", token, tokenSegment);
@ -113,7 +121,7 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
completionWindow.StartOffset = tokenSegment.Offset;
completionWindow.EndOffset = tokenSegment.EndOffset;
completionWindow.UpdateQuery(token);
completionWindow.UpdateQuery(completionRequest);
completionWindow.Closed += delegate
{
@ -169,14 +177,14 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
private static bool IsCompletionChar(char c)
{
return char.IsLetterOrDigit(c) || c == '_' || c == '-';
return char.IsLetterOrDigit(c) || c == '_' || c == '-' || c == ':';
}
/// <summary>
/// Gets a segment of the token the caret is currently in for completions.
/// Returns null if caret is not on a valid completion token (i.e. comments)
/// </summary>
private ISegment? GetCaretCompletionToken()
private EditorCompletionRequest? GetCaretCompletionToken()
{
var caret = textEditor.CaretOffset;
@ -223,26 +231,91 @@ public class TextEditorCompletionBehavior : Behavior<TextEditor>
var endOffset = currentToken.EndIndex + line.Offset;
// Cap the offsets by the line offsets
return new TextSegment
var segment = new TextSegment
{
StartOffset = Math.Max(startOffset, line.Offset),
EndOffset = Math.Min(endOffset, line.EndOffset)
};
// Search for the start and end of a token
// A token is defined as either alphanumeric chars or a space
/*var start = caret;
while (start > 0 && IsCompletionChar(textEditor.Document.GetCharAt(start - 1)))
// Check if this is an extra network request
if (currentToken.Scopes.Contains("meta.structure.network.prompt")
&& result.Tokens.ElementAtOrDefault(currentTokenIndex - 1) is { } prevToken)
{
start--;
}
// (case for initial '<type:')
// - Current token has "meta.structure.network" and "punctuation.separator.variable"
// - Previous token has "meta.structure.network" and "meta.embedded.network.type"
if (currentToken.Scopes.Contains("punctuation.separator.variable.prompt")
&& prevToken.Scopes.Contains("meta.structure.network.prompt")
&& prevToken.Scopes.Contains("meta.embedded.network.type.prompt"))
{
var networkToken = textEditor.Document.GetText(
prevToken.StartIndex + line.Offset, prevToken.Length);
PromptExtraNetworkType? networkTypeResult = networkToken.ToLowerInvariant() switch
{
"lora" => PromptExtraNetworkType.Lora,
"lyco" => PromptExtraNetworkType.LyCORIS,
"embedding" => PromptExtraNetworkType.Embedding,
_ => null
};
var end = caret;
while (end < textEditor.Document.TextLength && IsCompletionChar(textEditor.Document.GetCharAt(end)))
{
end++;
}
if (networkTypeResult is not { } networkType)
{
return null;
}
return new EditorCompletionRequest
{
Text = "",
Segment = segment,
Type = CompletionType.ExtraNetwork,
ExtraNetworkTypes = networkType,
};
}
// (case for already in model token '<type:network')
// - Current token has "meta.embedded.network.model"
if (currentToken.Scopes.Contains("meta.embedded.network.model.prompt"))
{
var secondPrevToken = result.Tokens.ElementAtOrDefault(currentTokenIndex - 2);
if (secondPrevToken is null)
{
return null;
}
var networkToken = textEditor.Document.GetText(
secondPrevToken.StartIndex + line.Offset, secondPrevToken.Length);
PromptExtraNetworkType? networkTypeResult = networkToken.ToLowerInvariant() switch
{
"lora" => PromptExtraNetworkType.Lora,
"lyco" => PromptExtraNetworkType.LyCORIS,
"embedding" => PromptExtraNetworkType.Embedding,
_ => null
};
return start < end ? new TextSegment { StartOffset = start, EndOffset = end } : null;*/
if (networkTypeResult is not { } networkType)
{
return null;
}
return new EditorCompletionRequest
{
Text = textEditor.Document.GetText(segment),
Segment = segment,
Type = CompletionType.ExtraNetwork,
ExtraNetworkTypes = networkType,
};
}
}
// Otherwise treat as tag
return new EditorCompletionRequest
{
Text = textEditor.Document.GetText(segment),
Segment = segment,
Type = CompletionType.Tag
};
}
}

244
StabilityMatrix.Avalonia/Behaviors/TextEditorToolTipBehavior.cs

@ -0,0 +1,244 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Avalonia;
using Avalonia.Controls;
using Avalonia.Controls.Presenters;
using Avalonia.Input;
using Avalonia.Media;
using Avalonia.Xaml.Interactivity;
using AvaloniaEdit;
using AvaloniaEdit.Document;
using NLog;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Extensions;
using TextMateSharp.Grammars;
namespace StabilityMatrix.Avalonia.Behaviors;
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public class TextEditorToolTipBehavior : Behavior<TextEditor>
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private TextEditor textEditor = null!;
/// <summary>
/// The current ToolTip, if open.
/// Is set to null when the Tooltip is closed.
/// </summary>
private ToolTip? toolTip;
public static readonly StyledProperty<ITokenizerProvider?> TokenizerProviderProperty =
AvaloniaProperty.Register<TextEditorCompletionBehavior, ITokenizerProvider?>(
"TokenizerProvider");
public ITokenizerProvider? TokenizerProvider
{
get => GetValue(TokenizerProviderProperty);
set => SetValue(TokenizerProviderProperty, value);
}
public static readonly StyledProperty<bool> IsEnabledProperty = AvaloniaProperty.Register<TextEditorCompletionBehavior, bool>(
"IsEnabled", true);
public bool IsEnabled
{
get => GetValue(IsEnabledProperty);
set => SetValue(IsEnabledProperty, value);
}
protected override void OnAttached()
{
base.OnAttached();
if (AssociatedObject is not { } editor)
{
throw new NullReferenceException("AssociatedObject is null");
}
textEditor = editor;
textEditor.PointerHover += TextEditor_OnPointerHover;
textEditor.PointerHoverStopped += TextEditor_OnPointerHoverStopped;
}
protected override void OnDetaching()
{
base.OnDetaching();
textEditor.PointerHover -= TextEditor_OnPointerHover;
textEditor.PointerHoverStopped -= TextEditor_OnPointerHoverStopped;
}
/*private void OnVisualLinesChanged(object? sender, EventArgs e)
{
_toolTip?.Close(this);
}*/
private static void TextEditor_OnPointerHoverStopped(object? sender, PointerEventArgs e)
{
if (sender is TextEditor editor)
{
ToolTip.SetIsOpen(editor, false);
e.Handled = true;
}
}
private void TextEditor_OnPointerHover(object? sender, PointerEventArgs e)
{
TextViewPosition? position;
var textArea = textEditor.TextArea;
try
{
position = textArea.TextView.GetPositionFloor(
e.GetPosition(textArea.TextView) + textArea.TextView.ScrollOffset);
}
catch (ArgumentOutOfRangeException)
{
// TODO: check why this happens
e.Handled = true;
return;
}
if (!position.HasValue || position.Value.Location.IsEmpty || position.Value.IsAtEndOfLine)
{
return;
}
/*var args = new ToolTipRequestEventArgs { InDocument = position.HasValue };
args.LogicalPosition = position.Value.Location;
args.Position = textEditor.Document.GetOffset(position.Value.Line, position.Value.Column);*/
// Get the ToolTip data
if (GetCaretToolTipData(position.Value) is not { } data)
{
return;
}
if (toolTip == null)
{
toolTip = new ToolTip
{
MaxWidth = 400
};
ToolTip.SetShowDelay(textEditor, 0);
ToolTip.SetPlacement(textEditor, PlacementMode.Pointer);
ToolTip.SetTip(textEditor, toolTip);
toolTip.GetPropertyChangedObservable(ToolTip.IsOpenProperty).Subscribe(c =>
{
if (c.NewValue as bool? != true)
{
toolTip = null;
}
});
}
toolTip.Content = new TextBlock
{
Text = data.Message,
TextWrapping = TextWrapping.Wrap
};
e.Handled = true;
ToolTip.SetIsOpen(textEditor, true);
toolTip.InvalidateVisual();
}
/// <summary>
/// Get ToolTip data to show at the caret position, can be null if no ToolTip should be shown.
/// </summary>
private ToolTipData? GetCaretToolTipData(TextViewPosition position)
{
var logicalPosition = position.Location;
var pointerOffset = textEditor.Document.GetOffset(logicalPosition.Line, logicalPosition.Column);
var line = textEditor.Document.GetLineByOffset(pointerOffset);
var lineText = textEditor.Document.GetText(line.Offset, line.Length);
var lineOffset = pointerOffset - line.Offset;
/*var caret = textEditor.CaretOffset;
// Get the line the caret is on
var line = textEditor.Document.GetLineByOffset(caret);
var lineText = textEditor.Document.GetText(line.Offset, line.Length);
var caretAbsoluteOffset = caret - line.Offset;*/
// Tokenize
var result = TokenizerProvider!.TokenizeLine(lineText);
var currentTokenIndex = -1;
IToken? currentToken = null;
// Get the token the caret is after
foreach (var (i, token) in result.Tokens.Enumerate())
{
// If we see a line comment token anywhere, return null
var isComment = token.Scopes.Any(s => s.Contains("comment.line"));
if (isComment)
{
Logger.Trace("Caret is in a comment");
return null;
}
// Find match
if (lineOffset >= token.StartIndex && lineOffset <= token.EndIndex)
{
currentTokenIndex = i;
currentToken = token;
break;
}
}
// Still not found
if (currentToken is null || currentTokenIndex == -1)
{
Logger.Info($"Could not find token at pointer offset {pointerOffset} for line {lineText.ToRepr()}");
return null;
}
var startOffset = currentToken.StartIndex + line.Offset;
var endOffset = currentToken.EndIndex + line.Offset;
// Cap the offsets by the line offsets
var segment = new TextSegment
{
StartOffset = Math.Max(startOffset, line.Offset),
EndOffset = Math.Min(endOffset, line.EndOffset)
};
// Only return for supported scopes
// Attempt with first current, then next and previous
foreach (var tokenOffset in new[] { 0, 1, -1 })
{
if (result.Tokens.ElementAtOrDefault(currentTokenIndex + tokenOffset) is { } token)
{
// Check supported scopes
if (token.Scopes.Where(s => s.Contains("invalid")).ToArray()
is { Length: > 0 } results)
{
// Special cases
if (results.Contains("invalid.illegal.mismatched.parenthesis.closing.prompt"))
{
return new ToolTipData(segment, "Mismatched closing parenthesis ')'");
}
if (results.Contains("invalid.illegal.mismatched.parenthesis.opening.prompt"))
{
return new ToolTipData(segment, "Mismatched opening parenthesis '('");
}
return new ToolTipData(segment, "Syntax error: " + string.Join(", ", results));
}
}
}
return null;
}
internal record ToolTipData(ISegment Segment, string Message);
}

13
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionData.cs

@ -1,5 +1,6 @@
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Avalonia.Controls;
using Avalonia.Controls.Documents;
using Avalonia.Media;
@ -23,8 +24,18 @@ public class CompletionData : ICompletionData
public string? Description { get; init; }
/// <inheritdoc />
public IImage? Image { get; set; }
public ImageSource? ImageSource { get; set; }
/// <summary>
/// Title of the image.
/// </summary>
public string? ImageTitle { get; set; }
/// <summary>
/// Subtitle of the image.
/// </summary>
public string? ImageSubtitle { get; set; }
/// <inheritdoc />
public IconData? Icon { get; init; }

6
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionIcons.cs

@ -52,6 +52,12 @@ public static class CompletionIcons
FAIcon = "fa-solid fa-key",
Foreground = ThemeColors.CompletionForegroundBrush,
};
public static readonly IconData Model = new()
{
FAIcon = "fa-solid fa-cube",
Foreground = ThemeColors.CompletionForegroundBrush,
};
public static IconData? GetIconForTagType(TagType tagType)
{

26
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionWindow.axaml.cs

@ -274,9 +274,14 @@ public class CompletionWindow : CompletionWindowBase
{
var newText = document.GetText(StartOffset, offset - StartOffset);
Debug.WriteLine("CaretPositionChanged newText: " + newText);
if (lastSearchRequest is not { } lastRequest)
{
return;
}
// CompletionList.SelectItem(newText);
Dispatcher.UIThread.Post(() => UpdateQuery(newText));
Dispatcher.UIThread.Post(() => UpdateQuery(lastRequest with { Text = newText }));
// UpdateQuery(newText);
IsVisible = CompletionList.ListBox!.ItemCount != 0;
@ -284,33 +289,36 @@ public class CompletionWindow : CompletionWindowBase
}
}
private string? lastSearchTerm;
private TextCompletionRequest? lastSearchRequest;
private int lastCompletionLength;
/// <summary>
/// Update the completion window's current search term.
/// </summary>
public void UpdateQuery(string searchTerm)
public void UpdateQuery(TextCompletionRequest completionRequest)
{
var searchTerm = completionRequest.Text;
// Fast path if the search term starts with the last search term
// and the last completion count was less than the max list length
// (such we won't get new results by searching again)
if (lastSearchTerm is not null
&& searchTerm.StartsWith(lastSearchTerm)
if (lastSearchRequest is not null
&& completionRequest.Type == lastSearchRequest.Type
&& searchTerm.StartsWith(lastSearchRequest.Text)
&& lastCompletionLength < MaxListLength)
{
CompletionList.SelectItem(searchTerm);
lastSearchTerm = searchTerm;
lastSearchRequest = completionRequest;
return;
}
var results = completionProvider.GetCompletions(searchTerm, MaxListLength, true);
var results = completionProvider.GetCompletions(completionRequest, MaxListLength, true);
CompletionList.CompletionData.Clear();
CompletionList.CompletionData.AddRange(results);
CompletionList.SelectItem(searchTerm, true);
lastSearchTerm = searchTerm;
lastSearchRequest = completionRequest;
lastCompletionLength = CompletionList.CompletionData.Count;
}
}

19
StabilityMatrix.Avalonia/Controls/CodeCompletion/ICompletionData.cs

@ -17,6 +17,7 @@
// DEALINGS IN THE SOFTWARE.
using System;
using System.Diagnostics.CodeAnalysis;
using Avalonia.Controls.Documents;
using Avalonia.Media;
using AvaloniaEdit.Document;
@ -48,7 +49,23 @@ public interface ICompletionData
/// <summary>
/// Gets the image.
/// </summary>
IImage? Image { get; }
ImageSource? ImageSource { get; }
/// <summary>
/// Title of the image.
/// </summary>
string? ImageTitle { get; }
/// <summary>
/// Subtitle of the image.
/// </summary>
string? ImageSubtitle { get; }
/// <summary>
/// Whether the image is available.
/// </summary>
[MemberNotNullWhen(true, nameof(ImageSource))]
bool HasImage => ImageSource != null;
/// <summary>
/// Gets the icon shown on the left.

111
StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarker.cs

@ -0,0 +1,111 @@
using System;
using System.Collections.Generic;
using Avalonia.Media;
using AvaloniaEdit.Document;
namespace StabilityMatrix.Avalonia.Controls.TextMarkers;
public sealed class TextMarker : TextSegment
{
private readonly TextMarkerService _service;
public TextMarker(TextMarkerService service, int startOffset, int length)
{
_service = service ?? throw new ArgumentNullException(nameof(service));
StartOffset = startOffset;
Length = length;
}
public event EventHandler? Deleted;
public bool IsDeleted => !IsConnectedToCollection;
public void Delete()
{
_service.Remove(this);
}
internal void OnDeleted()
{
Deleted?.Invoke(this, EventArgs.Empty);
}
private void Redraw()
{
_service.Redraw(this);
}
private Color? _backgroundColor;
public Color? BackgroundColor
{
get => _backgroundColor; set
{
if (!EqualityComparer<Color?>.Default.Equals(_backgroundColor, value))
{
_backgroundColor = value;
Redraw();
}
}
}
private Color? _foregroundColor;
public Color? ForegroundColor
{
get => _foregroundColor; set
{
if (!EqualityComparer<Color?>.Default.Equals(_foregroundColor, value))
{
_foregroundColor = value;
Redraw();
}
}
}
private FontWeight? _fontWeight;
public FontWeight? FontWeight
{
get => _fontWeight; set
{
if (_fontWeight != value)
{
_fontWeight = value;
Redraw();
}
}
}
private FontStyle? _fontStyle;
public FontStyle? FontStyle
{
get => _fontStyle; set
{
if (_fontStyle != value)
{
_fontStyle = value;
Redraw();
}
}
}
public object? Tag { get; set; }
private Color _markerColor;
public Color MarkerColor
{
get => _markerColor; set
{
if (!EqualityComparer<Color>.Default.Equals(_markerColor, value))
{
_markerColor = value;
Redraw();
}
}
}
public object? ToolTip { get; set; }
}

262
StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarkerService.cs

@ -0,0 +1,262 @@
// Copyright (c) 2014 AlphaSierraPapa for the SharpDevelop Team
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this
// software and associated documentation files (the "Software"), to deal in the Software
// without restriction, including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
// to whom the Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Avalonia;
using Avalonia.Media;
using AvaloniaEdit;
using AvaloniaEdit.Document;
using AvaloniaEdit.Rendering;
using CommonBrush = Avalonia.Media.IBrush;
namespace StabilityMatrix.Avalonia.Controls.TextMarkers;
public sealed class TextMarkerService : DocumentColorizingTransformer, IBackgroundRenderer, ITextViewConnect
{
private readonly TextSegmentCollection<TextMarker> _markers;
private readonly TextDocument _document;
private readonly List<TextView> _textViews;
public TextMarkerService(TextEditor editor)
{
if (editor == null) throw new ArgumentNullException(nameof(editor));
_document = editor.Document;
_markers = new TextSegmentCollection<TextMarker>(_document);
_textViews = new List<TextView>();
// editor.ToolTipRequest += EditorOnToolTipRequest;
}
/*private void EditorOnToolTipRequest(object? sender, ToolTipRequestEventArgs args)
{
var offset = _document.GetOffset(args.LogicalPosition);
//FoldingManager foldings = _editor.GetService(typeof(FoldingManager)) as FoldingManager;
//if (foldings != null)
//{
// var foldingsAtOffset = foldings.GetFoldingsAt(offset);
// FoldingSection collapsedSection = foldingsAtOffset.FirstOrDefault(section => section.IsFolded);
// if (collapsedSection != null)
// {
// args.SetToolTip(GetTooltipTextForCollapsedSection(args, collapsedSection));
// }
//}
var markersAtOffset = GetMarkersAtOffset(offset);
var markerWithToolTip = markersAtOffset.FirstOrDefault(marker => marker.ToolTip != null);
if (markerWithToolTip != null && markerWithToolTip.ToolTip != null)
{
args.SetToolTip(markerWithToolTip.ToolTip);
}
}*/
#region TextMarkerService
public TextMarker? TryCreate(int startOffset, int length)
{
if (_markers == null)
throw new InvalidOperationException("Cannot create a marker when not attached to a document");
var textLength = _document.TextLength;
if (startOffset < 0 || startOffset > textLength) return null;
//throw new ArgumentOutOfRangeException(nameof(startOffset), startOffset, "Value must be between 0 and " + textLength);
if (length < 0 || startOffset + length > textLength) return null;
//throw new ArgumentOutOfRangeException(nameof(length), length, "length must not be negative and startOffset+length must not be after the end of the document");
var marker = new TextMarker(this, startOffset, length);
_markers.Add(marker);
return marker;
}
public IEnumerable<TextMarker> GetMarkersAtOffset(int offset)
{
return _markers.FindSegmentsContaining(offset);
}
public IEnumerable<TextMarker> TextMarkers => _markers ?? Enumerable.Empty<TextMarker>();
public void RemoveAll(Predicate<TextMarker> predicate)
{
if (predicate == null)
throw new ArgumentNullException(nameof(predicate));
foreach (var m in _markers.ToArray())
{
if (predicate(m))
Remove(m);
}
}
public void Remove(TextMarker? marker)
{
if (marker == null) throw new ArgumentNullException(nameof(marker));
if (_markers.Remove(marker))
{
Redraw(marker);
marker.OnDeleted();
}
}
internal void Redraw(ISegment segment)
{
foreach (var view in _textViews)
{
view.Redraw(segment);
}
RedrawRequested?.Invoke(this, EventArgs.Empty);
}
public event EventHandler? RedrawRequested;
#endregion
#region DocumentColorizingTransformer
protected override void ColorizeLine(DocumentLine line)
{
var lineStart = line.Offset;
var lineEnd = lineStart + line.Length;
foreach (var marker in _markers.FindOverlappingSegments(lineStart, line.Length))
{
CommonBrush? foregroundBrush = null;
if (marker.ForegroundColor != null)
{
foregroundBrush = new SolidColorBrush(marker.ForegroundColor.Value).ToImmutable();
}
ChangeLinePart(
Math.Max(marker.StartOffset, lineStart),
Math.Min(marker.EndOffset, lineEnd),
element =>
{
if (foregroundBrush != null)
{
element.TextRunProperties.SetForegroundBrush(foregroundBrush);
}
var tf = element.TextRunProperties.Typeface;
element.TextRunProperties.SetTypeface(new Typeface(
tf.FontFamily,
marker.FontStyle ?? tf.Style,
marker.FontWeight ?? tf.Weight,
tf.Stretch
));
}
);
}
}
#endregion
#region IBackgroundRenderer
public KnownLayer Layer => KnownLayer.Selection;
public void Draw(TextView textView, DrawingContext drawingContext)
{
if (textView == null)
throw new ArgumentNullException(nameof(textView));
if (drawingContext == null)
throw new ArgumentNullException(nameof(drawingContext));
if (!textView.VisualLinesValid)
return;
var visualLines = textView.VisualLines;
if (visualLines.Count == 0)
return;
var viewStart = visualLines.First().FirstDocumentLine.Offset;
var viewEnd = visualLines.Last().LastDocumentLine.EndOffset;
foreach (var marker in _markers.FindOverlappingSegments(viewStart, viewEnd - viewStart))
{
if (marker.BackgroundColor != null)
{
var geoBuilder = new BackgroundGeometryBuilder
{
AlignToWholePixels = true,
CornerRadius = 3
};
geoBuilder.AddSegment(textView, marker);
var geometry = geoBuilder.CreateGeometry();
if (geometry != null)
{
var color = marker.BackgroundColor.Value;
var brush = new SolidColorBrush(color).ToImmutable();
drawingContext.DrawGeometry(brush, null, geometry);
}
}
foreach (var r in BackgroundGeometryBuilder.GetRectsForSegment(textView, marker))
{
var startPoint = r.BottomLeft;
var endPoint = r.BottomRight;
var usedBrush = new SolidColorBrush(marker.MarkerColor).ToImmutable();
var offset = 2.5;
var count = Math.Max((int)((endPoint.X - startPoint.X) / offset) + 1, 4);
/*var geometry = new StreamGeometry();
using (var ctx = geometry.Open())
{
ctx.BeginFigure(startPoint, false);
// ctx.PolyLineTo(CreatePoints(startPoint, offset, count).ToArray(), true, false);
ctx.LineTo(CreatePoints(startPoint, offset, count).ToArray());
}*/
var geometry = new PolylineGeometry(CreatePoints(startPoint, offset, count), false);
// geometry.Freeze();
var usedPen = new Pen(usedBrush, 1);
// usedPen.Freeze();
drawingContext.DrawGeometry(Brushes.Transparent, usedPen, geometry);
}
}
}
private static IEnumerable<Point> CreatePoints(Point start, double offset, int count)
{
for (var i = 0; i < count; i++)
yield return new Point(start.X + i * offset, start.Y - ((i + 1) % 2 == 0 ? offset : 0));
}
#endregion
#region ITextViewConnect
void ITextViewConnect.AddToTextView(TextView textView)
{
if (textView != null && !_textViews.Contains(textView))
{
Debug.Assert(textView.Document == _document);
_textViews.Add(textView);
}
}
void ITextViewConnect.RemoveFromTextView(TextView textView)
{
if (textView != null)
{
Debug.Assert(textView.Document == _document);
_textViews.Remove(textView);
}
}
#endregion
}

8
StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarkerValidationEventArgs.cs

@ -0,0 +1,8 @@
using System;
namespace StabilityMatrix.Avalonia.Controls.TextMarkers;
public class TextMarkerValidationEventArgs : EventArgs
{
}

63
StabilityMatrix.Avalonia/Controls/TextMarkers/TextMarkerValidatorService.cs

@ -0,0 +1,63 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
namespace StabilityMatrix.Avalonia.Controls.TextMarkers;
public class TextMarkerValidatorService
{
private string? currentText;
private Task? currentTask;
private TimeSpan updateInterval;
public EventHandler<TextMarkerValidationEventArgs>? ValidationUpdate;
private void OnValidationUpdate(TextMarkerValidationEventArgs e)
{
ValidationUpdate?.Invoke(this, e);
}
public TextMarkerValidatorService(TimeSpan updateInterval)
{
this.updateInterval = updateInterval;
}
public void UpdateText(string text)
{
// Ignore if text is the same
if (currentText == text) return;
// If previous task is not null, ignore
if (currentTask != null) return;
// Start a task to validate the text, and delay it by the update interval after the last update
currentTask = Task.Run(async () =>
{
await ValidateWithDelayAsync();
}).ContinueWith(_ =>
{
// Set task to null
currentTask = null;
// Set current text
currentText = text;
});
currentTask.SafeFireAndForget();
}
private void Validate()
{
}
private async Task ValidateWithDelayAsync(CancellationToken cancellationToken = default)
{
// Validate the text
Validate();
// Wait for the update interval
await Task.Delay(updateInterval, cancellationToken);
}
}

216
StabilityMatrix.Avalonia/Models/Inference/Prompt.cs

@ -0,0 +1,216 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using StabilityMatrix.Avalonia.Models.Inference.Tokens;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using TextMateSharp.Grammars;
namespace StabilityMatrix.Avalonia.Models.Inference;
public record Prompt
{
public required string RawText { get; init; }
public required ITokenizeLineResult TokenizeResult { get; init; }
private List<PromptExtraNetwork>? extraNetworks;
public IReadOnlyList<PromptExtraNetwork> ExtraNetworks => extraNetworks ??= GetExtraNetworks();
/// <summary>
/// Returns processed text suitable for sending to inference backend.
/// This excludes extra network (i.e. LORA) tokens.
/// </summary>
private string GetProcessedText()
{
// TODO
return RawText;
}
private List<PromptExtraNetwork> GetExtraNetworks()
{
// Parse tokens "meta.structure.network.prompt"
// "<": "punctuation.definition.network.begin.prompt"
// (type): "meta.embedded.network.type.prompt"
// ":": "punctuation.separator.variable.prompt"
// (content): "meta.embedded.network.model.prompt"
// ">": "punctuation.definition.network.end.prompt"
using var tokens = TokenizeResult.Tokens.Cast<IToken>().GetEnumerator();
// Store non-network tokens
var output = new StringBuilder();
// Store extra networks
var promptExtraNetworks = new List<PromptExtraNetwork>();
while (tokens.MoveNext())
{
var token = tokens.Current;
// Find start of network token, until then just add to output
if (!token.Scopes.Contains("punctuation.definition.network.begin.prompt"))
{
output.Append(RawText[token.StartIndex..token.EndIndex]);
continue;
}
// Expect next token to be network type
if (!tokens.MoveNext())
{
throw PromptSyntaxError.UnexpectedEndOfText(token.StartIndex, token.EndIndex);
}
var networkTypeToken = tokens.Current;
if (!networkTypeToken.Scopes.Contains("meta.embedded.network.type.prompt"))
{
throw PromptSyntaxError.Network_ExpectedType(
networkTypeToken.StartIndex, networkTypeToken.EndIndex);
}
var networkType = RawText[networkTypeToken.StartIndex..networkTypeToken.EndIndex];
// Match network type
var parsedNetworkType = networkType switch
{
"lora" => PromptExtraNetworkType.Lora,
"lycoris" => PromptExtraNetworkType.LyCORIS,
"embedding" => PromptExtraNetworkType.Embedding,
_ => throw PromptValidationError.Network_UnknownType(
networkTypeToken.StartIndex, networkTypeToken.EndIndex)
};
// Skip colon token
if (!tokens.MoveNext())
{
throw PromptSyntaxError.UnexpectedEndOfText(token.StartIndex, token.EndIndex);
}
// Ensure next token is colon
if (!tokens.Current.Scopes.Contains("punctuation.separator.variable.prompt"))
{
throw PromptSyntaxError.Network_ExpectedSeparator(
tokens.Current!.StartIndex, tokens.Current!.EndIndex);
}
// Get model name
if (!tokens.MoveNext())
{
throw PromptSyntaxError.UnexpectedEndOfText(token.StartIndex, token.EndIndex);
}
var modelNameToken = tokens.Current;
if (!tokens.Current.Scopes.Contains("meta.embedded.network.model.prompt"))
{
throw PromptSyntaxError.Network_ExpectedName(
tokens.Current!.StartIndex, tokens.Current!.EndIndex);
}
var modelName = RawText[modelNameToken.StartIndex..modelNameToken.EndIndex];
// Skip another colon token
if (!tokens.MoveNext())
{
throw PromptSyntaxError.UnexpectedEndOfText(token.StartIndex, token.EndIndex);
}
// If its a ending token instead, we can end here
if (tokens.Current.Scopes.Contains("punctuation.definition.network.end.prompt"))
{
promptExtraNetworks.Add(new PromptExtraNetwork
{
Type = parsedNetworkType,
Name = modelName
});
continue;
}
// Ensure next token is colon
if (!tokens.Current.Scopes.Contains("punctuation.separator.variable.prompt"))
{
throw PromptSyntaxError.Network_ExpectedSeparator(
tokens.Current!.StartIndex, tokens.Current!.EndIndex);
}
// Get model weight
if (!tokens.MoveNext())
{
throw PromptSyntaxError.UnexpectedEndOfText(token.StartIndex, token.EndIndex);
}
var modelWeightToken = tokens.Current;
if (!tokens.Current.Scopes.Contains("constant.numeric"))
{
throw PromptSyntaxError.Network_ExpectedWeight(
tokens.Current!.StartIndex, tokens.Current!.EndIndex);
}
var modelWeight = RawText[modelWeightToken.StartIndex..modelWeightToken.EndIndex];
// Convert to double
if (!double.TryParse(modelWeight, out var weight))
{
throw PromptValidationError.Network_InvalidWeight(
modelWeightToken.StartIndex, modelWeightToken.EndIndex);
}
// Expect end
if (!tokens.MoveNext())
{
throw PromptSyntaxError.UnexpectedEndOfText(token.StartIndex, token.EndIndex);
}
var endToken = tokens.Current;
if (!endToken.Scopes.Contains("punctuation.definition.network.end.prompt"))
{
throw PromptSyntaxError.UnexpectedEndOfText(
endToken.StartIndex, endToken.EndIndex);
}
// Add to output
promptExtraNetworks.Add(new PromptExtraNetwork
{
Type = parsedNetworkType,
Name = modelName,
ModelWeight = weight
});
}
return promptExtraNetworks;
}
public string GetDebugText()
{
var builder = new StringBuilder();
foreach (var token in TokenizeResult.Tokens)
{
// Get token text
var text = RawText[token.StartIndex..token.EndIndex];
// Format scope
var scopeStr = string.Join(", ", token.Scopes
.Where(s => s != "source.prompt")
.Select(s => s.EndsWith(".prompt") ? s.Remove(s.LastIndexOf(".prompt", StringComparison.Ordinal)) : s));
builder.AppendLine($"{text.ToRepr()} ({token.StartIndex}, {token.EndIndex})");
builder.AppendLine($" └─ {scopeStr}");
}
return builder.ToString();
}
public static Prompt FromRawText(string text, ITokenizerProvider tokenizer)
{
using var _ = new CodeTimer();
var result = tokenizer.TokenizeLine(text);
return new Prompt
{
RawText = text,
TokenizeResult = result
};
}
}

13
StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetwork.cs

@ -0,0 +1,13 @@
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens;
/// <summary>
/// Represents an extra network token in a prompt.
/// In format
/// </summary>
public record PromptExtraNetwork
{
public required PromptExtraNetworkType Type { get; init; }
public required string Name { get; init; }
public double? ModelWeight { get; init; }
public double? ClipWeight { get; init; }
}

16
StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetworkType.cs

@ -0,0 +1,16 @@
using System;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens;
[Flags]
public enum PromptExtraNetworkType
{
[ConvertTo<SharedFolderType>(SharedFolderType.Lora)]
Lora = 1 << 0,
[ConvertTo<SharedFolderType>(SharedFolderType.LyCORIS)]
LyCORIS = 1 << 1,
[ConvertTo<SharedFolderType>(SharedFolderType.TextualInversion)]
Embedding = 1 << 2
}

6
StabilityMatrix.Avalonia/Models/Inference/Tokens/TokenNames.cs

@ -0,0 +1,6 @@
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens;
public static class TokenNames
{
}

10
StabilityMatrix.Avalonia/Models/TagCompletion/CompletionType.cs

@ -0,0 +1,10 @@
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
/// <summary>
/// Type of completion requested.
/// </summary>
public enum CompletionType
{
Tag,
ExtraNetwork
}

8
StabilityMatrix.Avalonia/Models/TagCompletion/EditorCompletionRequest.cs

@ -0,0 +1,8 @@
using AvaloniaEdit.Document;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
public record EditorCompletionRequest : TextCompletionRequest
{
public required ISegment Segment { get; init; }
}

35
StabilityMatrix.Avalonia/Models/TagCompletion/ModelCompletionData.cs

@ -0,0 +1,35 @@
using System.IO;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Models.Inference.Tokens;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
public class ModelCompletionData : CompletionData
{
protected PromptExtraNetworkType NetworkType { get; }
/// <inheritdoc />
public ModelCompletionData(string text, PromptExtraNetworkType networkType) : base(text)
{
NetworkType = networkType;
// TODO: multi icons?
Icon = CompletionIcons.Model;
Description = networkType.GetStringValue();
}
public static ModelCompletionData FromLocalModel(LocalModelFile localModel, PromptExtraNetworkType networkType)
{
var displayName = Path.GetFileNameWithoutExtension(localModel.FileName);
return new ModelCompletionData(displayName, networkType)
{
ImageTitle = localModel.ConnectedModelInfo?.ModelName,
ImageSubtitle = localModel.ConnectedModelInfo?.VersionName,
ImageSource = localModel.PreviewImageFullPathGlobal is { } img
? new ImageSource(new FilePath(img))
: null
};
}
}

12
StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs

@ -0,0 +1,12 @@
using System.Collections.Generic;
using AvaloniaEdit.Document;
using StabilityMatrix.Avalonia.Models.Inference.Tokens;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
public record TextCompletionRequest
{
public required CompletionType Type { get; init; }
public required string Text { get; init; }
public PromptExtraNetworkType ExtraNetworkTypes { get; init; } = new();
}

13
StabilityMatrix.Core/Exceptions/PromptError.cs

@ -0,0 +1,13 @@
namespace StabilityMatrix.Core.Exceptions;
public abstract class PromptError : ApplicationException
{
public int TextOffset { get; }
public int TextEndOffset { get; }
protected PromptError(string message, int textOffset, int textEndOffset) : base(message)
{
TextOffset = textOffset;
TextEndOffset = textEndOffset;
}
}

24
StabilityMatrix.Core/Exceptions/PromptSyntaxError.cs

@ -0,0 +1,24 @@
namespace StabilityMatrix.Core.Exceptions;
public class PromptSyntaxError : PromptError
{
public static PromptSyntaxError Network_ExpectedSeparator(int textOffset, int textEndOffset) =>
new("Expected separator", textOffset, textEndOffset);
public static PromptSyntaxError Network_ExpectedType(int textOffset, int textEndOffset) =>
new("Expected network type", textOffset, textEndOffset);
public static PromptSyntaxError Network_ExpectedName(int textOffset, int textEndOffset) =>
new("Expected network name", textOffset, textEndOffset);
public static PromptSyntaxError Network_ExpectedWeight(int textOffset, int textEndOffset) =>
new("Expected network weight", textOffset, textEndOffset);
public static PromptSyntaxError UnexpectedEndOfText(int textOffset, int textEndOffset) =>
new("Unexpected end of text", textOffset, textEndOffset);
/// <inheritdoc />
public PromptSyntaxError(string message, int textOffset, int textEndOffset) : base(message, textOffset, textEndOffset)
{
}
}

15
StabilityMatrix.Core/Exceptions/PromptValidationError.cs

@ -0,0 +1,15 @@
namespace StabilityMatrix.Core.Exceptions;
public class PromptValidationError : PromptError
{
/// <inheritdoc />
public PromptValidationError(string message, int textOffset, int textEndOffset) : base(message, textOffset, textEndOffset)
{
}
public static PromptValidationError Network_UnknownType(int textOffset, int textEndOffset) =>
new("Unknown network type", textOffset, textEndOffset);
public static PromptSyntaxError Network_InvalidWeight(int textOffset, int textEndOffset) =>
new("Invalid network weight, could not be parsed as double", textOffset, textEndOffset);
}
Loading…
Cancel
Save