Browse Source

Add generation metadata & smproj to inference output images

pull/165/head
JT 1 year ago
parent
commit
ff8ece34c5
  1. 88
      StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs
  2. 15
      StabilityMatrix.Avalonia/Models/Inference/GenerationParameters.cs
  3. 46
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  4. 1
      StabilityMatrix.Core/StabilityMatrix.Core.csproj

88
StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs

@ -0,0 +1,88 @@
using System;
using System.Linq;
using System.Text;
using System.Text.Json;
using Force.Crc32;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
namespace StabilityMatrix.Avalonia.Helpers;
public static class PngDataHelper
{
private static readonly byte[] Idat = { 0x49, 0x44, 0x41, 0x54 };
private static readonly byte[] Text = { 0x74, 0x45, 0x58, 0x74 };
private static readonly byte[] Iend = { 0x49, 0x45, 0x4E, 0x44 };
public static byte[] AddMetadata(
byte[] inputImage,
GenerationParameters generationParameters,
InferenceProjectDocument projectDocument
)
{
var imageWidthBytes = inputImage[0x10..0x14];
var imageHeightBytes = inputImage[0x14..0x18];
var imageWidth = BitConverter.ToInt32(imageWidthBytes.Reverse().ToArray());
var imageHeight = BitConverter.ToInt32(imageHeightBytes.Reverse().ToArray());
var idatIndex = SearchBytes(inputImage, Idat);
var iendIndex = SearchBytes(inputImage, Iend);
var textEndIndex = idatIndex - 4; // go back 4 cuz we don't want the length
var existingData = inputImage[..textEndIndex];
var smprojJson = JsonSerializer.Serialize(projectDocument);
var smprojChunk = GetTextChunk("smproj", smprojJson);
var paramsData =
$"{generationParameters.PositivePrompt}\nNegative prompt: {generationParameters.NegativePrompt}\n"
+ $"Steps: {generationParameters.Steps}, Sampler: {generationParameters.Sampler}, "
+ $"CFG scale: {generationParameters.CfgScale}, Seed: {generationParameters.Seed}, "
+ $"Size: {imageWidth}x{imageHeight}, "
+ $"Model hash: {generationParameters.ModelHash}, Model: {generationParameters.ModelName}";
var paramsChunk = GetTextChunk("parameters", paramsData);
// Go back 4 from the idat index because we need the length of the data
idatIndex -= 4;
// Go forward 8 from the iend index because we need the crc
iendIndex += 8;
var actualImageData = inputImage[idatIndex..iendIndex];
var finalImage = existingData
.Concat(smprojChunk)
.Concat(paramsChunk)
.Concat(actualImageData);
return finalImage.ToArray();
}
private static byte[] GetTextChunk(string key, string value)
{
var textData = $"{key}\0{value}";
var textDataLength = BitConverter.GetBytes(textData.Length).Reverse();
var textDataBytes = Text.Concat(Encoding.UTF8.GetBytes(textData)).ToArray();
var crc = BitConverter.GetBytes(Crc32Algorithm.Compute(textDataBytes));
return textDataLength.Concat(textDataBytes).Concat(crc).ToArray();
}
private static int SearchBytes(byte[] haystack, byte[] needle)
{
var limit = haystack.Length - needle.Length;
for (var i = 0; i <= limit; i++)
{
var k = 0;
for (; k < needle.Length; k++)
{
if (needle[k] != haystack[i + k])
break;
}
if (k == needle.Length)
return i;
}
return -1;
}
}

15
StabilityMatrix.Avalonia/Models/Inference/GenerationParameters.cs

@ -0,0 +1,15 @@
namespace StabilityMatrix.Avalonia.Models.Inference;
public class GenerationParameters
{
public string PositivePrompt { get; set; }
public string NegativePrompt { get; set; }
public int Steps { get; set; }
public string Sampler { get; set; }
public double CfgScale { get; set; }
public ulong Seed { get; set; }
public int Height { get; set; }
public int Width { get; set; }
public string ModelHash { get; set; }
public string ModelName { get; set; }
}

46
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq; using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
@ -17,6 +18,7 @@ using SkiaSharp;
using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
@ -354,6 +356,19 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
var (nodes, outputNodeNames) = BuildPrompt(overrides); var (nodes, outputNodeNames) = BuildPrompt(overrides);
var generationInfo = new GenerationParameters
{
Seed = (ulong)seedCard.Seed,
Steps = SamplerCardViewModel.Steps,
CfgScale = SamplerCardViewModel.CfgScale,
Sampler = SamplerCardViewModel.SelectedSampler?.Name,
ModelName = ModelCardViewModel.SelectedModelName,
// TODO: ModelHash
PositivePrompt = PromptCardViewModel.PromptDocument.Text,
NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text
};
var smproj = InferenceProjectDocument.FromLoadable(this);
// Connect preview image handler // Connect preview image handler
client.PreviewImageReceived += OnPreviewImageReceived; client.PreviewImageReceived += OnPreviewImageReceived;
@ -406,9 +421,26 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Use local file path if available, otherwise use remote URL // Use local file path if available, otherwise use remote URL
if (client.OutputImagesDir is { } outputPath) if (client.OutputImagesDir is { } outputPath)
{ {
outputImages = images! outputImages = new List<ImageSource>();
.Select(i => new ImageSource(i.ToFilePath(outputPath))) foreach (var image in images)
.ToList(); {
var filePath = image.ToFilePath(outputPath);
var fileStream = new BinaryReader(filePath.Info.OpenRead());
var bytes = fileStream.ReadBytes((int)filePath.Info.Length);
var bytesWithMetadata = PngDataHelper.AddMetadata(
bytes,
generationInfo,
smproj
);
fileStream.Close();
fileStream.Dispose();
await using var outputStream = filePath.Info.OpenWrite();
await outputStream.WriteAsync(bytesWithMetadata);
await outputStream.FlushAsync();
outputImages.Add(new ImageSource(filePath));
}
} }
else else
{ {
@ -425,6 +457,12 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
.ToImmutableArray(); .ToImmutableArray();
var grid = ImageProcessor.CreateImageGrid(loadedImages); var grid = ImageProcessor.CreateImageGrid(loadedImages);
var gridBytes = grid.Encode().ToArray();
var gridBytesWithMetadata = PngDataHelper.AddMetadata(
gridBytes,
generationInfo,
smproj
);
// Save to disk // Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name; var lastName = outputImages.Last().LocalFile?.Info.Name;
@ -432,7 +470,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
await using (var fileStream = gridPath.Info.OpenWrite()) await using (var fileStream = gridPath.Info.OpenWrite())
{ {
await fileStream.WriteAsync(grid.Encode().ToArray(), cancellationToken); await fileStream.WriteAsync(gridBytesWithMetadata, cancellationToken);
} }
// Insert to start of images // Insert to start of images

1
StabilityMatrix.Core/StabilityMatrix.Core.csproj

@ -16,6 +16,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="AsyncAwaitBestPractices" Version="6.0.6" /> <PackageReference Include="AsyncAwaitBestPractices" Version="6.0.6" />
<PackageReference Include="Blake3" Version="0.5.1" /> <PackageReference Include="Blake3" Version="0.5.1" />
<PackageReference Include="Crc32.NET" Version="1.2.0" />
<PackageReference Include="DeviceId" Version="6.3.0" /> <PackageReference Include="DeviceId" Version="6.3.0" />
<PackageReference Include="DeviceId.Linux" Version="6.3.0" /> <PackageReference Include="DeviceId.Linux" Version="6.3.0" />
<PackageReference Include="DeviceId.Mac" Version="6.2.0" /> <PackageReference Include="DeviceId.Mac" Version="6.2.0" />

Loading…
Cancel
Save