Browse Source

Change Inference to use downloaded images and custom file name formatting

pull/240/head
Ionite 1 year ago
parent
commit
a2c3acb952
No known key found for this signature in database
  1. 14
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 47
      StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs
  3. 62
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  4. 194
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  5. 4
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
  6. 3
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

14
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.ImageSize = builder.Connections.LatentSize;
}
var saveImage = builder.Nodes.AddNamedNode(
var previewImage = builder.Nodes.AddNamedNode(
new NamedComfyNode("SaveImage")
{
ClassType = "SaveImage",
Inputs = new Dictionary<string, object?>
{
["filename_prefix"] = "Inference/TextToImage",
["images"] = builder.Connections.Image
}
ClassType = "PreviewImage",
Inputs = new Dictionary<string, object?> { ["images"] = builder.Connections.Image }
}
);
builder.Connections.OutputNodes.Add(saveImage);
builder.Connections.OutputNodes.Add(previewImage);
return saveImage.Name;
return previewImage.Name;
}
}

47
StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs

@ -13,50 +13,57 @@ public static class ImageProcessor
/// </summary>
public static (int rows, int columns) GetGridDimensionsFromImageCount(int count)
{
if (count <= 1) return (1, 1);
if (count == 2) return (1, 2);
if (count <= 1)
return (1, 1);
if (count == 2)
return (1, 2);
// Prefer one extra row over one extra column,
// the row count will be the floor of the square root
// and the column count will be floor of count / rows
var rows = (int) Math.Floor(Math.Sqrt(count));
var columns = (int) Math.Floor((double) count / rows);
var rows = (int)Math.Floor(Math.Sqrt(count));
var columns = (int)Math.Floor((double)count / rows);
return (rows, columns);
}
public static SKImage CreateImageGrid(
IReadOnlyList<SKImage> images,
int spacing = 0)
public static SKImage CreateImageGrid(IReadOnlyList<SKImage> images, int spacing = 0)
{
if (images.Count == 0)
throw new ArgumentException("Must have at least one image");
var (rows, columns) = GetGridDimensionsFromImageCount(images.Count);
var singleWidth = images[0].Width;
var singleHeight = images[0].Height;
// Make output image
using var output = new SKBitmap(
singleWidth * columns + spacing * (columns - 1),
singleHeight * rows + spacing * (rows - 1));
singleWidth * columns + spacing * (columns - 1),
singleHeight * rows + spacing * (rows - 1)
);
// Draw images
using var canvas = new SKCanvas(output);
foreach (var (row, column) in
Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns)))
foreach (
var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))
)
{
// Stop if we have drawn all images
var index = row * columns + column;
if (index >= images.Count) break;
if (index >= images.Count)
break;
// Get image
var image = images[index];
// Draw image
var destination = new SKRect(
singleWidth * column + spacing * column,
singleHeight * row + spacing * row,
singleWidth * column + spacing * column + image.Width,
singleHeight * row + spacing * row + image.Height);
singleHeight * row + spacing * row + image.Height
);
canvas.DrawImage(image, destination);
}

62
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -1,5 +1,6 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@ -345,6 +346,61 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
}
private async Task MigrateLinksIfNeeded(PackagePair packagePair)
{
if (packagePair.InstalledPackage.FullPath is not { } packagePath)
{
throw new ArgumentException("Package path is null", nameof(packagePair));
}
var newDestination = settingsManager.ImagesInferenceDirectory;
// If new destination is a reparse point (like before, delete it first)
if (newDestination is { IsSymbolicLink: true, Info.LinkTarget: null })
{
logger.LogInformation("Deleting existing link target at {NewDir}", newDestination);
newDestination.Info.Attributes = FileAttributes.Normal;
await newDestination.DeleteAsync(false).ConfigureAwait(false);
}
newDestination.Create();
// For locally installed packages only
// Move all files in ./output/Inference to /Images/Inference and delete ./output/Inference
var legacyLinkSource = new DirectoryPath(packagePair.InstalledPackage.FullPath).JoinDir(
"output",
"Inference"
);
if (!legacyLinkSource.Exists)
{
return;
}
// Move files if not empty
if (legacyLinkSource.Info.EnumerateFiles().Any())
{
logger.LogInformation(
"Moving files from {LegacyDir} to {NewDir}",
legacyLinkSource,
newDestination
);
await FileTransfers
.MoveAllFilesAndDirectories(
legacyLinkSource,
newDestination,
overwriteIfHashMatches: true,
overwrite: false
)
.ConfigureAwait(false);
}
// Delete legacy link
logger.LogInformation("Deleting legacy link at {LegacyDir}", legacyLinkSource);
legacyLinkSource.Info.Attributes = FileAttributes.Normal;
await legacyLinkSource.DeleteAsync(false).ConfigureAwait(false);
}
/// <inheritdoc />
public async Task ConnectAsync(
PackagePair packagePair,
@ -367,11 +423,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
logger.LogError(ex, "Error setting up completion provider");
});
// Setup image folder links
await comfyPackage.SetupInferenceOutputFolderLinks(
packagePair.InstalledPackage.FullPath
?? throw new InvalidOperationException("Package does not have a Path")
);
await MigrateLinksIfNeeded(packagePair);
// Get user defined host and port
var host = packagePair.InstalledPackage.GetLaunchArgsHost();

194
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -3,11 +3,14 @@ using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls.Notifications;
using Avalonia.Media.Imaging;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using NLog;
@ -27,6 +30,8 @@ using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
@ -41,6 +46,7 @@ public abstract partial class InferenceGenerationViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ISettingsManager settingsManager;
private readonly INotificationService notificationService;
private readonly ServiceManager<ViewModelBase> vmFactory;
@ -60,11 +66,13 @@ public abstract partial class InferenceGenerationViewModelBase
protected InferenceGenerationViewModelBase(
ServiceManager<ViewModelBase> vmFactory,
IInferenceClientManager inferenceClientManager,
INotificationService notificationService
INotificationService notificationService,
ISettingsManager settingsManager
)
: base(notificationService)
{
this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.vmFactory = vmFactory;
ClientManager = inferenceClientManager;
@ -75,6 +83,100 @@ public abstract partial class InferenceGenerationViewModelBase
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
}
/// <summary>
/// Write an image to the default output folder
/// </summary>
protected Task<FilePath> WriteOutputImageAsync(
Stream imageStream,
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
)
{
var defaultOutputDir = settingsManager.ImagesInferenceDirectory;
defaultOutputDir.Create();
return WriteOutputImageAsync(
imageStream,
defaultOutputDir,
args,
batchNum,
batchTotal,
isGrid
);
}
/// <summary>
/// Write an image to an output folder
/// </summary>
protected async Task<FilePath> WriteOutputImageAsync(
Stream imageStream,
DirectoryPath outputDir,
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
)
{
var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat;
var formatProvider = new FileNameFormatProvider
{
GenerationParameters = args.Parameters,
ProjectType = args.Project?.ProjectType,
ProjectName = ProjectFile?.NameWithoutExtension
};
// Parse to format
if (
string.IsNullOrEmpty(formatTemplateStr)
|| !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format)
)
{
// Fallback to default
Logger.Warn(
"Failed to parse format template: {FormatTemplate}, using default",
formatTemplateStr
);
format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider);
}
if (isGrid)
{
format = format.WithGridPrefix();
}
if (batchNum >= 1 && batchTotal > 1)
{
format = format.WithBatchPostFix(batchNum, batchTotal);
}
var fileName = format.GetFileName() + ".png";
var file = outputDir.JoinFile(fileName);
// Until the file is free, keep adding _{i} to the end
for (var i = 0; i < 100; i++)
{
if (!file.Exists)
break;
file = outputDir.JoinFile($"{fileName}_{i + 1}");
}
// If that fails, append an 7-char uuid
if (file.Exists)
{
file = outputDir.JoinFile($"{fileName}_{Guid.NewGuid():N}"[..7]);
}
await using var fileStream = file.Info.OpenWrite();
await imageStream.CopyToAsync(fileStream);
return file;
}
/// <summary>
/// Builds the image generation prompt
/// </summary>
@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished");
Logger.Debug($"Prompt task {promptTask.Id} finished");
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase
cancellationToken
);
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
|| images is not { Count: > 0 }
)
{
// No images match
notificationService.Show(
"No output",
"Did not receive any output images",
NotificationType.Warning
);
return;
}
// Disable cancellation
await promptInterrupt.DisposeAsync();
@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGalleryCardViewModel.ImageSources.Clear();
}
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
await ProcessOutputImages(images, args);
}
finally
@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGenerationEventArgs args
)
{
var client = args.Client;
// Write metadata to images
var outputImagesBytes = new List<byte[]>();
var outputImages = new List<ImageSource>();
foreach (
var (i, filePath) in images
.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
.Enumerate()
)
foreach (var (i, comfyImage) in images.Enumerate())
{
if (!filePath.Exists)
{
Logger.Warn($"Image file {filePath} does not exist");
continue;
}
Logger.Debug("Downloading image: {FileName}", comfyImage.FileName);
var imageStream = await client.GetImageStreamAsync(comfyImage);
using var ms = new MemoryStream();
await imageStream.CopyToAsync(ms);
var imageArray = ms.ToArray();
outputImagesBytes.Add(imageArray);
var parameters = args.Parameters!;
var project = args.Project!;
@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase
);
}
var bytesWithMetadata = PngDataHelper.AddMetadata(
await filePath.ReadAllBytesAsync(),
parameters,
project
);
var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project);
await using (var outputStream = filePath.Info.OpenWrite())
{
await outputStream.WriteAsync(bytesWithMetadata);
await outputStream.FlushAsync();
}
// Write using generated name
var filePath = await WriteOutputImageAsync(
new MemoryStream(bytesWithMetadata),
args,
i + 1,
images.Count
);
outputImages.Add(new ImageSource(filePath));
@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
var outputDir = outputImages[0].LocalFile!.Directory;
var loadedImages = outputImages
.Select(i => i.LocalFile)
.Where(f => f is { Exists: true })
.Select(f =>
{
using var stream = f!.Info.OpenRead();
return SKImage.FromEncodedData(stream);
})
.ToImmutableArray();
var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray();
var project = args.Project!;
@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase
);
// Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = outputDir!.JoinFile($"grid-{lastName}");
await using (var fileStream = gridPath.Info.OpenWrite())
{
await fileStream.WriteAsync(gridBytesWithMetadata);
}
var gridPath = await WriteOutputImageAsync(
new MemoryStream(gridBytesWithMetadata),
args,
isGrid: true
);
// Insert to start of images
var gridImage = new ImageSource(gridPath);

4
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs

@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
using Path = System.IO.Path;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
@ -60,9 +61,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
public InferenceImageUpscaleViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> vmFactory
)
: base(vmFactory, inferenceClientManager, notificationService)
: base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{
this.notificationService = notificationService;

3
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -86,10 +86,11 @@ public class InferenceTextToImageViewModel
public InferenceTextToImageViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> vmFactory,
IModelIndexService modelIndexService
)
: base(vmFactory, inferenceClientManager, notificationService)
: base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{
this.notificationService = notificationService;
this.modelIndexService = modelIndexService;

Loading…
Cancel
Save