Browse Source

Merge branch 'dev' of https://github.com/ionite34/StabilityMatrix into selectable-image-buttons

pull/240/head
JT 1 year ago
parent
commit
4b54fd3de2
  1. 4
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs
  2. 22
      StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionListBox.cs
  3. 14
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  4. 31
      StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs
  5. 11
      StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs
  6. 73
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs
  7. 16
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs
  8. 192
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs
  9. 8
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs
  10. 45
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  11. 3
      StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
  12. 194
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  13. 4
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
  14. 3
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  15. 64
      StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
  16. 46
      StabilityMatrix.Avalonia/Views/SettingsPage.axaml
  17. 2
      StabilityMatrix.Core/Models/Database/LocalImageFile.cs
  18. 20
      StabilityMatrix.Core/Models/GenerationParameters.cs
  19. 5
      StabilityMatrix.Core/Models/Settings/Settings.cs
  20. 28
      StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs
  21. 24
      StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs

4
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionList.cs

@ -202,11 +202,11 @@ public class CompletionList : TemplatedControl
{ {
case Key.Down: case Key.Down:
e.Handled = true; e.Handled = true;
_listBox.SelectIndex(_listBox.SelectedIndex + 1); _listBox.SelectNextIndexWithLoop();
break; break;
case Key.Up: case Key.Up:
e.Handled = true; e.Handled = true;
_listBox.SelectIndex(_listBox.SelectedIndex - 1); _listBox.SelectPreviousIndexWithLoop();
break; break;
case Key.PageDown: case Key.PageDown:
e.Handled = true; e.Handled = true;

22
StabilityMatrix.Avalonia/Controls/CodeCompletion/CompletionListBox.cs

@ -91,6 +91,28 @@ public class CompletionListBox : ListBox
SelectedIndex = -1; SelectedIndex = -1;
} }
/// <summary>
/// Selects the next item. If the last item is already selected, selects the first item.
/// </summary>
public void SelectNextIndexWithLoop()
{
if (ItemCount <= 0)
return;
SelectIndex((SelectedIndex + 1) % ItemCount);
}
/// <summary>
/// Selects the previous item. If the first item is already selected, selects the last item.
/// </summary>
public void SelectPreviousIndexWithLoop()
{
if (ItemCount <= 0)
return;
SelectIndex((SelectedIndex - 1 + ItemCount) % ItemCount);
}
/// <summary> /// <summary>
/// Selects the item with the specified index and scrolls it into view. /// Selects the item with the specified index and scrolls it into view.
/// </summary> /// </summary>

14
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.ImageSize = builder.Connections.LatentSize; builder.Connections.ImageSize = builder.Connections.LatentSize;
} }
var saveImage = builder.Nodes.AddNamedNode( var previewImage = builder.Nodes.AddNamedNode(
new NamedComfyNode("SaveImage") new NamedComfyNode("SaveImage")
{ {
ClassType = "SaveImage", ClassType = "PreviewImage",
Inputs = new Dictionary<string, object?> Inputs = new Dictionary<string, object?> { ["images"] = builder.Connections.Image }
{
["filename_prefix"] = "Inference/TextToImage",
["images"] = builder.Connections.Image
}
} }
); );
builder.Connections.OutputNodes.Add(saveImage); builder.Connections.OutputNodes.Add(previewImage);
return saveImage.Name; return previewImage.Name;
} }
} }

31
StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs

@ -13,21 +13,24 @@ public static class ImageProcessor
/// </summary> /// </summary>
public static (int rows, int columns) GetGridDimensionsFromImageCount(int count) public static (int rows, int columns) GetGridDimensionsFromImageCount(int count)
{ {
if (count <= 1) return (1, 1); if (count <= 1)
if (count == 2) return (1, 2); return (1, 1);
if (count == 2)
return (1, 2);
// Prefer one extra row over one extra column, // Prefer one extra row over one extra column,
// the row count will be the floor of the square root // the row count will be the floor of the square root
// and the column count will be floor of count / rows // and the column count will be floor of count / rows
var rows = (int) Math.Floor(Math.Sqrt(count)); var rows = (int)Math.Floor(Math.Sqrt(count));
var columns = (int) Math.Floor((double) count / rows); var columns = (int)Math.Floor((double)count / rows);
return (rows, columns); return (rows, columns);
} }
public static SKImage CreateImageGrid( public static SKImage CreateImageGrid(IReadOnlyList<SKImage> images, int spacing = 0)
IReadOnlyList<SKImage> images,
int spacing = 0)
{ {
if (images.Count == 0)
throw new ArgumentException("Must have at least one image");
var (rows, columns) = GetGridDimensionsFromImageCount(images.Count); var (rows, columns) = GetGridDimensionsFromImageCount(images.Count);
var singleWidth = images[0].Width; var singleWidth = images[0].Width;
@ -36,17 +39,20 @@ public static class ImageProcessor
// Make output image // Make output image
using var output = new SKBitmap( using var output = new SKBitmap(
singleWidth * columns + spacing * (columns - 1), singleWidth * columns + spacing * (columns - 1),
singleHeight * rows + spacing * (rows - 1)); singleHeight * rows + spacing * (rows - 1)
);
// Draw images // Draw images
using var canvas = new SKCanvas(output); using var canvas = new SKCanvas(output);
foreach (var (row, column) in foreach (
Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))) var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))
)
{ {
// Stop if we have drawn all images // Stop if we have drawn all images
var index = row * columns + column; var index = row * columns + column;
if (index >= images.Count) break; if (index >= images.Count)
break;
// Get image // Get image
var image = images[index]; var image = images[index];
@ -56,7 +62,8 @@ public static class ImageProcessor
singleWidth * column + spacing * column, singleWidth * column + spacing * column,
singleHeight * row + spacing * row, singleHeight * row + spacing * row,
singleWidth * column + spacing * column + image.Width, singleWidth * column + spacing * column + image.Width,
singleHeight * row + spacing * row + image.Height); singleHeight * row + spacing * row + image.Height
);
canvas.DrawImage(image, destination); canvas.DrawImage(image, destination);
} }

11
StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs

@ -16,6 +16,17 @@ public static class PngDataHelper
private static readonly byte[] Text = { 0x74, 0x45, 0x58, 0x74 }; private static readonly byte[] Text = { 0x74, 0x45, 0x58, 0x74 };
private static readonly byte[] Iend = { 0x49, 0x45, 0x4E, 0x44 }; private static readonly byte[] Iend = { 0x49, 0x45, 0x4E, 0x44 };
public static byte[] AddMetadata(
Stream inputStream,
GenerationParameters generationParameters,
InferenceProjectDocument projectDocument
)
{
using var ms = new MemoryStream();
inputStream.CopyTo(ms);
return AddMetadata(ms.ToArray(), generationParameters, projectDocument);
}
public static byte[] AddMetadata( public static byte[] AddMetadata(
byte[] inputImage, byte[] inputImage,
GenerationParameters generationParameters, GenerationParameters generationParameters,

73
StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs

@ -0,0 +1,73 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
namespace StabilityMatrix.Avalonia.Models.Inference;
public record FileNameFormat
{
public string Template { get; }
public string Prefix { get; set; } = "";
public string Postfix { get; set; } = "";
public IReadOnlyList<FileNameFormatPart> Parts { get; }
private FileNameFormat(string template, IReadOnlyList<FileNameFormatPart> parts)
{
Template = template;
Parts = parts;
}
public FileNameFormat WithBatchPostFix(int current, int total)
{
return this with { Postfix = Postfix + $" ({current}-{total})" };
}
public FileNameFormat WithGridPrefix()
{
return this with { Prefix = Prefix + "Grid_" };
}
public string GetFileName()
{
return Prefix
+ string.Join(
"",
Parts.Select(
part => part.Match(constant => constant, substitution => substitution.Invoke())
)
)
+ Postfix;
}
public static FileNameFormat Parse(string template, FileNameFormatProvider provider)
{
var parts = provider.GetParts(template).ToImmutableArray();
return new FileNameFormat(template, parts);
}
public static bool TryParse(
string template,
FileNameFormatProvider provider,
[NotNullWhen(true)] out FileNameFormat? format
)
{
try
{
format = Parse(template, provider);
return true;
}
catch (ArgumentException)
{
format = null;
return false;
}
}
public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}";
}

16
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs

@ -0,0 +1,16 @@
using System;
using System.Runtime.InteropServices;
using CSharpDiscriminatedUnion.Attributes;
namespace StabilityMatrix.Avalonia.Models.Inference;
[GenerateDiscriminatedUnion(CaseFactoryPrefix = "From")]
[StructLayout(LayoutKind.Auto)]
public readonly partial struct FileNameFormatPart
{
[StructCase("Constant", isDefaultValue: true)]
private readonly string constant;
[StructCase("Substitution")]
private readonly Func<string?> substitution;
}

192
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs

@ -0,0 +1,192 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Text.RegularExpressions;
using Avalonia.Data;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Models.Inference;
public partial class FileNameFormatProvider
{
public GenerationParameters? GenerationParameters { get; init; }
public InferenceProjectType? ProjectType { get; init; }
public string? ProjectName { get; init; }
private Dictionary<string, Func<string?>>? _substitutions;
public Dictionary<string, Func<string?>> Substitutions =>
_substitutions ??= new Dictionary<string, Func<string?>>
{
{ "seed", () => GenerationParameters?.Seed.ToString() },
{ "prompt", () => GenerationParameters?.PositivePrompt },
{ "negative_prompt", () => GenerationParameters?.NegativePrompt },
{ "model_name", () => GenerationParameters?.ModelName },
{ "model_hash", () => GenerationParameters?.ModelHash },
{ "width", () => GenerationParameters?.Width.ToString() },
{ "height", () => GenerationParameters?.Height.ToString() },
{ "project_type", () => ProjectType?.GetStringValue() },
{ "project_name", () => ProjectName },
{ "date", () => DateTime.Now.ToString("yyyy-MM-dd") },
{ "time", () => DateTime.Now.ToString("HH-mm-ss") }
};
/// <summary>
/// Validate a format string
/// </summary>
/// <param name="format">Format string</param>
/// <exception cref="DataValidationException">Thrown if the format string contains an unknown variable</exception>
[Pure]
public ValidationResult Validate(string format)
{
var regex = BracketRegex();
var matches = regex.Matches(format);
var variables = matches.Select(m => m.Groups[1].Value);
foreach (var variableText in variables)
{
try
{
var (variable, _) = ExtractVariableAndSlice(variableText);
if (!Substitutions.ContainsKey(variable))
{
return new ValidationResult($"Unknown variable '{variable}'");
}
}
catch (Exception e)
{
return new ValidationResult($"Invalid variable '{variableText}': {e.Message}");
}
}
return ValidationResult.Success!;
}
public IEnumerable<FileNameFormatPart> GetParts(string template)
{
var regex = BracketRegex();
var matches = regex.Matches(template);
var parts = new List<FileNameFormatPart>();
// Loop through all parts of the string, including matches and non-matches
var currentIndex = 0;
foreach (var result in matches.Cast<Match>())
{
// If the match is not at the start of the string, add a constant part
if (result.Index != currentIndex)
{
var constant = template[currentIndex..result.Index];
parts.Add(FileNameFormatPart.FromConstant(constant));
currentIndex += constant.Length;
}
// Now we're at start of the current match, add the variable part
var (variable, slice) = ExtractVariableAndSlice(result.Groups[1].Value);
var substitution = Substitutions[variable];
// Slice string if necessary
if (slice is not null)
{
parts.Add(
FileNameFormatPart.FromSubstitution(() =>
{
var value = substitution();
if (value is null)
return null;
if (slice.End is null)
{
value = value[(slice.Start ?? 0)..];
}
else
{
var length =
Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0);
value = value.Substring(slice.Start ?? 0, length);
}
return value;
})
);
}
else
{
parts.Add(FileNameFormatPart.FromSubstitution(substitution));
}
currentIndex += result.Length;
}
// Add remaining as constant
if (currentIndex != template.Length)
{
var constant = template[currentIndex..];
parts.Add(FileNameFormatPart.FromConstant(constant));
}
return parts;
}
/// <summary>
/// Return a sample provider for UI preview
/// </summary>
public static FileNameFormatProvider GetSample()
{
return new FileNameFormatProvider
{
GenerationParameters = GenerationParameters.GetSample(),
ProjectType = InferenceProjectType.TextToImage,
ProjectName = "Sample Project"
};
}
/// <summary>
/// Extract variable and index from a combined string
/// </summary>
private static (string Variable, Slice? Slice) ExtractVariableAndSlice(string combined)
{
if (IndexRegex().Matches(combined).FirstOrDefault() is not { Success: true } match)
{
return (combined, null);
}
// Variable is everything before the match
var variable = combined[..match.Groups[0].Index];
var start = match.Groups["start"].Value;
var end = match.Groups["end"].Value;
var step = match.Groups["step"].Value;
var slice = new Slice(
string.IsNullOrEmpty(start) ? null : int.Parse(start),
string.IsNullOrEmpty(end) ? null : int.Parse(end),
string.IsNullOrEmpty(step) ? null : int.Parse(step)
);
return (variable, slice);
}
/// <summary>
/// Regex for matching contents within a curly brace.
/// </summary>
[GeneratedRegex(@"\{([a-z_:\d\[\]]+)\}")]
private static partial Regex BracketRegex();
/// <summary>
/// Regex for matching a Python-like array index.
/// </summary>
[GeneratedRegex(@"\[(?:(?<start>-?\d+)?)\:(?:(?<end>-?\d+)?)?(?:\:(?<step>-?\d+))?\]")]
private static partial Regex IndexRegex();
private record Slice(int? Start, int? End, int? Step);
}

8
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs

@ -0,0 +1,8 @@
namespace StabilityMatrix.Avalonia.Models.Inference;
public record FileNameFormatVar
{
public required string Variable { get; init; }
public string? Example { get; init; }
}

45
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -1,5 +1,6 @@
using System; using System;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -345,6 +346,44 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken); return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
} }
private async Task MigrateLinksIfNeeded(PackagePair packagePair)
{
if (packagePair.InstalledPackage.FullPath is not { } packagePath)
{
throw new ArgumentException("Package path is null", nameof(packagePair));
}
var inferenceDir = settingsManager.ImagesInferenceDirectory;
inferenceDir.Create();
// For locally installed packages only
// Delete ./output/Inference
var legacyInferenceLinkDir = new DirectoryPath(
packagePair.InstalledPackage.FullPath
).JoinDir("output", "Inference");
if (legacyInferenceLinkDir.Exists)
{
logger.LogInformation(
"Deleting legacy inference link at {LegacyDir}",
legacyInferenceLinkDir
);
if (legacyInferenceLinkDir.IsSymbolicLink)
{
await legacyInferenceLinkDir.DeleteAsync(false);
}
else
{
logger.LogWarning(
"Legacy inference link at {LegacyDir} is not a symbolic link, skipping",
legacyInferenceLinkDir
);
}
}
}
/// <inheritdoc /> /// <inheritdoc />
public async Task ConnectAsync( public async Task ConnectAsync(
PackagePair packagePair, PackagePair packagePair,
@ -367,11 +406,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
logger.LogError(ex, "Error setting up completion provider"); logger.LogError(ex, "Error setting up completion provider");
}); });
// Setup image folder links await MigrateLinksIfNeeded(packagePair);
await comfyPackage.SetupInferenceOutputFolderLinks(
packagePair.InstalledPackage.FullPath
?? throw new InvalidOperationException("Package does not have a Path")
);
// Get user defined host and port // Get user defined host and port
var host = packagePair.InstalledPackage.GetLaunchArgsHost(); var host = packagePair.InstalledPackage.GetLaunchArgsHost();

3
StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

@ -8,7 +8,7 @@
<ApplicationManifest>app.manifest</ApplicationManifest> <ApplicationManifest>app.manifest</ApplicationManifest>
<AvaloniaUseCompiledBindingsByDefault>true</AvaloniaUseCompiledBindingsByDefault> <AvaloniaUseCompiledBindingsByDefault>true</AvaloniaUseCompiledBindingsByDefault>
<ApplicationIcon>./Assets/Icon.ico</ApplicationIcon> <ApplicationIcon>./Assets/Icon.ico</ApplicationIcon>
<Version>2.5.3-dev.1</Version> <Version>2.6.0-dev.1</Version>
<InformationalVersion>$(Version)</InformationalVersion> <InformationalVersion>$(Version)</InformationalVersion>
<EnableWindowsTargeting>true</EnableWindowsTargeting> <EnableWindowsTargeting>true</EnableWindowsTargeting>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
@ -32,6 +32,7 @@
<PackageReference Include="Avalonia.Xaml.Behaviors" Version="11.0.2" /> <PackageReference Include="Avalonia.Xaml.Behaviors" Version="11.0.2" />
<PackageReference Include="AvaloniaEdit.TextMate" Version="11.0.0" /> <PackageReference Include="AvaloniaEdit.TextMate" Version="11.0.0" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="8.2.1" /> <PackageReference Include="CommunityToolkit.Mvvm" Version="8.2.1" />
<PackageReference Include="CSharpDiscriminatedUnion" Version="2.0.1" />
<PackageReference Include="DiscordRichPresence" Version="1.2.1.24" /> <PackageReference Include="DiscordRichPresence" Version="1.2.1.24" />
<PackageReference Include="Dock.Avalonia" Version="11.0.0.2" /> <PackageReference Include="Dock.Avalonia" Version="11.0.0.2" />
<PackageReference Include="Dock.Model.Avalonia" Version="11.0.0.2" /> <PackageReference Include="Dock.Model.Avalonia" Version="11.0.0.2" />

194
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -3,11 +3,13 @@ using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics; using System.Diagnostics;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq; using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AsyncAwaitBestPractices; using AsyncAwaitBestPractices;
using Avalonia.Controls.Notifications;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using NLog; using NLog;
@ -27,6 +29,8 @@ using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Base; namespace StabilityMatrix.Avalonia.ViewModels.Base;
@ -41,6 +45,7 @@ public abstract partial class InferenceGenerationViewModelBase
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ISettingsManager settingsManager;
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private readonly ServiceManager<ViewModelBase> vmFactory; private readonly ServiceManager<ViewModelBase> vmFactory;
@ -60,11 +65,13 @@ public abstract partial class InferenceGenerationViewModelBase
protected InferenceGenerationViewModelBase( protected InferenceGenerationViewModelBase(
ServiceManager<ViewModelBase> vmFactory, ServiceManager<ViewModelBase> vmFactory,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
INotificationService notificationService INotificationService notificationService,
ISettingsManager settingsManager
) )
: base(notificationService) : base(notificationService)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.vmFactory = vmFactory; this.vmFactory = vmFactory;
ClientManager = inferenceClientManager; ClientManager = inferenceClientManager;
@ -75,6 +82,101 @@ public abstract partial class InferenceGenerationViewModelBase
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
} }
/// <summary>
/// Write an image to the default output folder
/// </summary>
protected Task<FilePath> WriteOutputImageAsync(
Stream imageStream,
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
)
{
var defaultOutputDir = settingsManager.ImagesInferenceDirectory;
defaultOutputDir.Create();
return WriteOutputImageAsync(
imageStream,
defaultOutputDir,
args,
batchNum,
batchTotal,
isGrid
);
}
/// <summary>
/// Write an image to an output folder
/// </summary>
protected async Task<FilePath> WriteOutputImageAsync(
Stream imageStream,
DirectoryPath outputDir,
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
)
{
var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat;
var formatProvider = new FileNameFormatProvider
{
GenerationParameters = args.Parameters,
ProjectType = args.Project?.ProjectType,
ProjectName = ProjectFile?.NameWithoutExtension
};
// Parse to format
if (
string.IsNullOrEmpty(formatTemplateStr)
|| !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format)
)
{
// Fallback to default
Logger.Warn(
"Failed to parse format template: {FormatTemplate}, using default",
formatTemplateStr
);
format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider);
}
if (isGrid)
{
format = format.WithGridPrefix();
}
if (batchNum >= 1 && batchTotal > 1)
{
format = format.WithBatchPostFix(batchNum, batchTotal);
}
var fileName = format.GetFileName();
var file = outputDir.JoinFile($"{fileName}.png");
// Until the file is free, keep adding _{i} to the end
for (var i = 0; i < 100; i++)
{
if (!file.Exists)
break;
file = outputDir.JoinFile($"{fileName}_{i + 1}.png");
}
// If that fails, append an 7-char uuid
if (file.Exists)
{
var uuid = Guid.NewGuid().ToString("N")[..7];
file = outputDir.JoinFile($"{fileName}_{uuid}.png");
}
await using var fileStream = file.Info.OpenWrite();
await imageStream.CopyToAsync(fileStream);
return file;
}
/// <summary> /// <summary>
/// Builds the image generation prompt /// Builds the image generation prompt
/// </summary> /// </summary>
@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Wait for prompt to finish // Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken); await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished"); Logger.Debug($"Prompt task {promptTask.Id} finished");
// Get output images // Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync( var imageOutputs = await client.GetImagesForExecutedPromptAsync(
@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase
cancellationToken cancellationToken
); );
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
|| images is not { Count: > 0 }
)
{
// No images match
notificationService.Show(
"No output",
"Did not receive any output images",
NotificationType.Warning
);
return;
}
// Disable cancellation // Disable cancellation
await promptInterrupt.DisposeAsync(); await promptInterrupt.DisposeAsync();
@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGalleryCardViewModel.ImageSources.Clear(); ImageGalleryCardViewModel.ImageSources.Clear();
} }
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
await ProcessOutputImages(images, args); await ProcessOutputImages(images, args);
} }
finally finally
@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGenerationEventArgs args ImageGenerationEventArgs args
) )
{ {
var client = args.Client;
// Write metadata to images // Write metadata to images
var outputImagesBytes = new List<byte[]>();
var outputImages = new List<ImageSource>(); var outputImages = new List<ImageSource>();
foreach (
var (i, filePath) in images foreach (var (i, comfyImage) in images.Enumerate())
.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
.Enumerate()
)
{
if (!filePath.Exists)
{ {
Logger.Warn($"Image file {filePath} does not exist"); Logger.Debug("Downloading image: {FileName}", comfyImage.FileName);
continue; var imageStream = await client.GetImageStreamAsync(comfyImage);
}
using var ms = new MemoryStream();
await imageStream.CopyToAsync(ms);
var imageArray = ms.ToArray();
outputImagesBytes.Add(imageArray);
var parameters = args.Parameters!; var parameters = args.Parameters!;
var project = args.Project!; var project = args.Project!;
@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase
); );
} }
var bytesWithMetadata = PngDataHelper.AddMetadata( var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project);
await filePath.ReadAllBytesAsync(),
parameters,
project
);
await using (var outputStream = filePath.Info.OpenWrite()) // Write using generated name
{ var filePath = await WriteOutputImageAsync(
await outputStream.WriteAsync(bytesWithMetadata); new MemoryStream(bytesWithMetadata),
await outputStream.FlushAsync(); args,
} i + 1,
images.Count
);
outputImages.Add(new ImageSource(filePath)); outputImages.Add(new ImageSource(filePath));
@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Download all images to make grid, if multiple // Download all images to make grid, if multiple
if (outputImages.Count > 1) if (outputImages.Count > 1)
{ {
var outputDir = outputImages[0].LocalFile!.Directory; var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray();
var loadedImages = outputImages
.Select(i => i.LocalFile)
.Where(f => f is { Exists: true })
.Select(f =>
{
using var stream = f!.Info.OpenRead();
return SKImage.FromEncodedData(stream);
})
.ToImmutableArray();
var project = args.Project!; var project = args.Project!;
@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase
); );
// Save to disk // Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name; var gridPath = await WriteOutputImageAsync(
var gridPath = outputDir!.JoinFile($"grid-{lastName}"); new MemoryStream(gridBytesWithMetadata),
args,
await using (var fileStream = gridPath.Info.OpenWrite()) isGrid: true
{ );
await fileStream.WriteAsync(gridBytesWithMetadata);
}
// Insert to start of images // Insert to start of images
var gridImage = new ImageSource(gridPath); var gridImage = new ImageSource(gridPath);

4
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs

@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
using Path = System.IO.Path; using Path = System.IO.Path;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration #pragma warning disable CS0657 // Not a valid attribute location for this declaration
@ -60,9 +61,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
public InferenceImageUpscaleViewModel( public InferenceImageUpscaleViewModel(
INotificationService notificationService, INotificationService notificationService,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> vmFactory ServiceManager<ViewModelBase> vmFactory
) )
: base(vmFactory, inferenceClientManager, notificationService) : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;

3
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -86,10 +86,11 @@ public class InferenceTextToImageViewModel
public InferenceTextToImageViewModel( public InferenceTextToImageViewModel(
INotificationService notificationService, INotificationService notificationService,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> vmFactory, ServiceManager<ViewModelBase> vmFactory,
IModelIndexService modelIndexService IModelIndexService modelIndexService
) )
: base(vmFactory, inferenceClientManager, notificationService) : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;
this.modelIndexService = modelIndexService; this.modelIndexService = modelIndexService;

64
StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs

@ -3,10 +3,12 @@ using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.ComponentModel; using System.ComponentModel;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics; using System.Diagnostics;
using System.Globalization; using System.Globalization;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Reflection; using System.Reflection;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
@ -21,6 +23,7 @@ using Avalonia.Styling;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using DynamicData.Binding;
using FluentAvalonia.UI.Controls; using FluentAvalonia.UI.Controls;
using NLog; using NLog;
using SkiaSharp; using SkiaSharp;
@ -29,6 +32,7 @@ using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
@ -107,6 +111,25 @@ public partial class SettingsViewModel : PageViewModelBase
[ObservableProperty] [ObservableProperty]
private bool isCompletionRemoveUnderscoresEnabled = true; private bool isCompletionRemoveUnderscoresEnabled = true;
[ObservableProperty]
[CustomValidation(typeof(SettingsViewModel), nameof(ValidateOutputImageFileNameFormat))]
private string? outputImageFileNameFormat;
[ObservableProperty]
private string? outputImageFileNameFormatSample;
public IEnumerable<FileNameFormatVar> OutputImageFileNameFormatVars =>
FileNameFormatProvider
.GetSample()
.Substitutions.Select(
kv =>
new FileNameFormatVar
{
Variable = $"{{{kv.Key}}}",
Example = kv.Value.Invoke()
}
);
[ObservableProperty] [ObservableProperty]
private bool isImageViewerPixelGridEnabled = true; private bool isImageViewerPixelGridEnabled = true;
@ -201,6 +224,39 @@ public partial class SettingsViewModel : PageViewModelBase
true true
); );
this.WhenPropertyChanged(vm => vm.OutputImageFileNameFormat)
.Throttle(TimeSpan.FromMilliseconds(50))
.Subscribe(formatProperty =>
{
var provider = FileNameFormatProvider.GetSample();
var template = formatProperty.Value;
if (
!string.IsNullOrEmpty(template)
&& provider.Validate(template) == ValidationResult.Success
)
{
var format = FileNameFormat.Parse(template, provider);
OutputImageFileNameFormatSample = format.GetFileName() + ".png";
}
else
{
// Use default format if empty
var defaultFormat = FileNameFormat.Parse(
FileNameFormat.DefaultTemplate,
provider
);
OutputImageFileNameFormatSample = defaultFormat.GetFileName() + ".png";
}
});
settingsManager.RelayPropertyFor(
this,
vm => vm.OutputImageFileNameFormat,
settings => settings.InferenceOutputImageFileNameFormat,
true
);
settingsManager.RelayPropertyFor( settingsManager.RelayPropertyFor(
this, this,
vm => vm.IsImageViewerPixelGridEnabled, vm => vm.IsImageViewerPixelGridEnabled,
@ -225,6 +281,14 @@ public partial class SettingsViewModel : PageViewModelBase
UpdateAvailableTagCompletionCsvs(); UpdateAvailableTagCompletionCsvs();
} }
public static ValidationResult ValidateOutputImageFileNameFormat(
string format,
ValidationContext context
)
{
return FileNameFormatProvider.GetSample().Validate(format);
}
partial void OnSelectedThemeChanged(string? value) partial void OnSelectedThemeChanged(string? value)
{ {
// In case design / tests // In case design / tests

46
StabilityMatrix.Avalonia/Views/SettingsPage.axaml

@ -6,10 +6,15 @@
xmlns:controls="clr-namespace:StabilityMatrix.Avalonia.Controls" xmlns:controls="clr-namespace:StabilityMatrix.Avalonia.Controls"
xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:fluentIcons="clr-namespace:FluentIcons.FluentAvalonia;assembly=FluentIcons.FluentAvalonia"
xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData" xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData"
xmlns:ui="using:FluentAvalonia.UI.Controls" xmlns:ui="using:FluentAvalonia.UI.Controls"
xmlns:vm="clr-namespace:StabilityMatrix.Avalonia.ViewModels" xmlns:vm="clr-namespace:StabilityMatrix.Avalonia.ViewModels"
xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages" xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
xmlns:avaloniaEdit="https://github.com/avaloniaui/avaloniaedit"
xmlns:inference="clr-namespace:StabilityMatrix.Avalonia.Models.Inference"
xmlns:mdxaml="https://github.com/whistyun/Markdown.Avalonia.Tight"
Focusable="True"
d:DataContext="{x:Static mocks:DesignData.SettingsViewModel}" d:DataContext="{x:Static mocks:DesignData.SettingsViewModel}"
d:DesignHeight="700" d:DesignHeight="700"
d:DesignWidth="800" d:DesignWidth="800"
@ -83,10 +88,10 @@
</Grid> </Grid>
<!-- Inference UI --> <!-- Inference UI -->
<Grid Margin="0,8,0,0" RowDefinitions="auto,*,*"> <Grid Margin="0,8,0,0" RowDefinitions="auto,*,*,*">
<TextBlock <TextBlock
FontWeight="Medium" FontWeight="Medium"
Text="Inference UI" Text="Inference"
Margin="0,0,0,8" /> Margin="0,0,0,8" />
<!-- Auto Completion --> <!-- Auto Completion -->
<ui:SettingsExpander Grid.Row="1" <ui:SettingsExpander Grid.Row="1"
@ -155,6 +160,43 @@
</ui:SettingsExpanderItem.Footer> </ui:SettingsExpanderItem.Footer>
</ui:SettingsExpanderItem> </ui:SettingsExpanderItem>
</ui:SettingsExpander> </ui:SettingsExpander>
<!-- Output Image Files -->
<ui:SettingsExpander Grid.Row="3"
Header="Output Image Files"
Margin="8,0,8,4">
<ui:SettingsExpander.IconSource>
<fluentIcons:SymbolIconSource Symbol="TabDesktopImage"/>
</ui:SettingsExpander.IconSource>
<!-- File name pattern -->
<ui:SettingsExpanderItem
Content="File name pattern"
Description="{Binding OutputImageFileNameFormatSample}"
IconSource="Rename">
<ui:SettingsExpanderItem.Footer>
<TextBox
Name="OutputImageFileNameFormatTextBox"
Watermark="{x:Static inference:FileNameFormat.DefaultTemplate}"
FontSize="13"
MinWidth="150"
Text="{Binding OutputImageFileNameFormat}"
FontFamily="Cascadia Code,Consolas,Menlo,Monospace"/>
</ui:SettingsExpanderItem.Footer>
</ui:SettingsExpanderItem>
</ui:SettingsExpander>
<ui:TeachingTip
IsOpen="{Binding #OutputImageFileNameFormatTextBox.IsFocused}"
Target="{Binding #OutputImageFileNameFormatTextBox, Mode=OneWay}"
PreferredPlacement="Top"
Title="Format Variables"
Grid.Row="3">
<DataGrid
AutoGenerateColumns="True"
ItemsSource="{Binding OutputImageFileNameFormatVars}" />
<!--<mdxaml:MarkdownScrollViewer
Markdown="{Binding OutputImageFileNameFormatGuideMarkdown}"/>-->
</ui:TeachingTip>
</Grid> </Grid>
<!-- Environment Options --> <!-- Environment Options -->

2
StabilityMatrix.Core/Models/Database/LocalImageFile.cs

@ -126,6 +126,8 @@ public class LocalImageFile
GenerationParameters.TryParse(metadata, out genParams); GenerationParameters.TryParse(metadata, out genParams);
} }
filePath.Info.Refresh();
return new LocalImageFile return new LocalImageFile
{ {
RelativePath = relativePath, RelativePath = relativePath,

20
StabilityMatrix.Core/Models/GenerationParameters.cs

@ -126,6 +126,26 @@ public partial record GenerationParameters
return (sampler, scheduler); return (sampler, scheduler);
} }
/// <summary>
/// Return a sample parameters for UI preview
/// </summary>
public static GenerationParameters GetSample()
{
return new GenerationParameters
{
PositivePrompt = "(cat:1.2), by artist, detailed, [shaded]",
NegativePrompt = "blurry, jpg artifacts",
Steps = 30,
CfgScale = 7,
Width = 640,
Height = 896,
Seed = 124825529,
ModelName = "ExampleMix7",
ModelHash = "b899d188a1ac7356bfb9399b2277d5b21712aa360f8f9514fba6fcce021baff7",
Sampler = "DPM++ 2M Karras"
};
}
// Example: Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1 // Example: Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1
[GeneratedRegex( [GeneratedRegex(
"""^Steps: (?<Steps>\d+), Sampler: (?<Sampler>.+?), CFG scale: (?<CfgScale>\d+(\.\d+)?), Seed: (?<Seed>\d+), Size: (?<Width>\d+)x(?<Height>\d+), Model hash: (?<ModelHash>.+?), Model: (?<ModelName>.+)$""" """^Steps: (?<Steps>\d+), Sampler: (?<Sampler>.+?), CFG scale: (?<CfgScale>\d+(\.\d+)?), Seed: (?<Seed>\d+), Size: (?<Width>\d+)x(?<Height>\d+), Model hash: (?<ModelHash>.+?), Model: (?<ModelName>.+)$"""

5
StabilityMatrix.Core/Models/Settings/Settings.cs

@ -70,6 +70,11 @@ public class Settings
/// </summary> /// </summary>
public bool IsCompletionRemoveUnderscoresEnabled { get; set; } = true; public bool IsCompletionRemoveUnderscoresEnabled { get; set; } = true;
/// <summary>
/// Format for Inference output image file names
/// </summary>
public string? InferenceOutputImageFileNameFormat { get; set; }
/// <summary> /// <summary>
/// Whether the Inference Image Viewer shows pixel grids at high zoom levels /// Whether the Inference Image Viewer shows pixel grids at high zoom levels
/// </summary> /// </summary>

28
StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs

@ -0,0 +1,28 @@
using System.ComponentModel.DataAnnotations;
using StabilityMatrix.Avalonia.Models.Inference;
namespace StabilityMatrix.Tests.Avalonia;
[TestClass]
public class FileNameFormatProviderTests
{
[TestMethod]
public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow()
{
var provider = new FileNameFormatProvider();
var result = provider.Validate("{date}_{time}-{model_name}-{seed}");
Assert.AreEqual(ValidationResult.Success, result);
}
[TestMethod]
public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow()
{
var provider = new FileNameFormatProvider();
var result = provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}");
Assert.AreNotEqual(ValidationResult.Success, result);
Assert.AreEqual("Unknown variable 'invalid'", result.ErrorMessage);
}
}

24
StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs

@ -0,0 +1,24 @@
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Tests.Avalonia;
[TestClass]
public class FileNameFormatTests
{
[TestMethod]
public void TestFileNameFormatParse()
{
var provider = new FileNameFormatProvider
{
GenerationParameters = new GenerationParameters { Seed = 123 },
ProjectName = "uwu",
ProjectType = InferenceProjectType.TextToImage,
};
var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider);
Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName());
}
}
Loading…
Cancel
Save