Browse Source

Change Inference to use downloaded images and custom file name formatting

pull/240/head
Ionite 1 year ago
parent
commit
a2c3acb952
No known key found for this signature in database
  1. 14
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  2. 47
      StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs
  3. 62
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  4. 194
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  5. 4
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
  6. 3
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

14
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.ImageSize = builder.Connections.LatentSize; builder.Connections.ImageSize = builder.Connections.LatentSize;
} }
var saveImage = builder.Nodes.AddNamedNode( var previewImage = builder.Nodes.AddNamedNode(
new NamedComfyNode("SaveImage") new NamedComfyNode("SaveImage")
{ {
ClassType = "SaveImage", ClassType = "PreviewImage",
Inputs = new Dictionary<string, object?> Inputs = new Dictionary<string, object?> { ["images"] = builder.Connections.Image }
{
["filename_prefix"] = "Inference/TextToImage",
["images"] = builder.Connections.Image
}
} }
); );
builder.Connections.OutputNodes.Add(saveImage); builder.Connections.OutputNodes.Add(previewImage);
return saveImage.Name; return previewImage.Name;
} }
} }

47
StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs

@ -13,50 +13,57 @@ public static class ImageProcessor
/// </summary> /// </summary>
public static (int rows, int columns) GetGridDimensionsFromImageCount(int count) public static (int rows, int columns) GetGridDimensionsFromImageCount(int count)
{ {
if (count <= 1) return (1, 1); if (count <= 1)
if (count == 2) return (1, 2); return (1, 1);
if (count == 2)
return (1, 2);
// Prefer one extra row over one extra column, // Prefer one extra row over one extra column,
// the row count will be the floor of the square root // the row count will be the floor of the square root
// and the column count will be floor of count / rows // and the column count will be floor of count / rows
var rows = (int) Math.Floor(Math.Sqrt(count)); var rows = (int)Math.Floor(Math.Sqrt(count));
var columns = (int) Math.Floor((double) count / rows); var columns = (int)Math.Floor((double)count / rows);
return (rows, columns); return (rows, columns);
} }
public static SKImage CreateImageGrid( public static SKImage CreateImageGrid(IReadOnlyList<SKImage> images, int spacing = 0)
IReadOnlyList<SKImage> images,
int spacing = 0)
{ {
if (images.Count == 0)
throw new ArgumentException("Must have at least one image");
var (rows, columns) = GetGridDimensionsFromImageCount(images.Count); var (rows, columns) = GetGridDimensionsFromImageCount(images.Count);
var singleWidth = images[0].Width; var singleWidth = images[0].Width;
var singleHeight = images[0].Height; var singleHeight = images[0].Height;
// Make output image // Make output image
using var output = new SKBitmap( using var output = new SKBitmap(
singleWidth * columns + spacing * (columns - 1), singleWidth * columns + spacing * (columns - 1),
singleHeight * rows + spacing * (rows - 1)); singleHeight * rows + spacing * (rows - 1)
);
// Draw images // Draw images
using var canvas = new SKCanvas(output); using var canvas = new SKCanvas(output);
foreach (var (row, column) in foreach (
Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))) var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))
)
{ {
// Stop if we have drawn all images // Stop if we have drawn all images
var index = row * columns + column; var index = row * columns + column;
if (index >= images.Count) break; if (index >= images.Count)
break;
// Get image // Get image
var image = images[index]; var image = images[index];
// Draw image // Draw image
var destination = new SKRect( var destination = new SKRect(
singleWidth * column + spacing * column, singleWidth * column + spacing * column,
singleHeight * row + spacing * row, singleHeight * row + spacing * row,
singleWidth * column + spacing * column + image.Width, singleWidth * column + spacing * column + image.Width,
singleHeight * row + spacing * row + image.Height); singleHeight * row + spacing * row + image.Height
);
canvas.DrawImage(image, destination); canvas.DrawImage(image, destination);
} }

62
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -1,5 +1,6 @@
using System; using System;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -345,6 +346,61 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken); return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
} }
private async Task MigrateLinksIfNeeded(PackagePair packagePair)
{
if (packagePair.InstalledPackage.FullPath is not { } packagePath)
{
throw new ArgumentException("Package path is null", nameof(packagePair));
}
var newDestination = settingsManager.ImagesInferenceDirectory;
// If new destination is a reparse point (like before, delete it first)
if (newDestination is { IsSymbolicLink: true, Info.LinkTarget: null })
{
logger.LogInformation("Deleting existing link target at {NewDir}", newDestination);
newDestination.Info.Attributes = FileAttributes.Normal;
await newDestination.DeleteAsync(false).ConfigureAwait(false);
}
newDestination.Create();
// For locally installed packages only
// Move all files in ./output/Inference to /Images/Inference and delete ./output/Inference
var legacyLinkSource = new DirectoryPath(packagePair.InstalledPackage.FullPath).JoinDir(
"output",
"Inference"
);
if (!legacyLinkSource.Exists)
{
return;
}
// Move files if not empty
if (legacyLinkSource.Info.EnumerateFiles().Any())
{
logger.LogInformation(
"Moving files from {LegacyDir} to {NewDir}",
legacyLinkSource,
newDestination
);
await FileTransfers
.MoveAllFilesAndDirectories(
legacyLinkSource,
newDestination,
overwriteIfHashMatches: true,
overwrite: false
)
.ConfigureAwait(false);
}
// Delete legacy link
logger.LogInformation("Deleting legacy link at {LegacyDir}", legacyLinkSource);
legacyLinkSource.Info.Attributes = FileAttributes.Normal;
await legacyLinkSource.DeleteAsync(false).ConfigureAwait(false);
}
/// <inheritdoc /> /// <inheritdoc />
public async Task ConnectAsync( public async Task ConnectAsync(
PackagePair packagePair, PackagePair packagePair,
@ -367,11 +423,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
logger.LogError(ex, "Error setting up completion provider"); logger.LogError(ex, "Error setting up completion provider");
}); });
// Setup image folder links await MigrateLinksIfNeeded(packagePair);
await comfyPackage.SetupInferenceOutputFolderLinks(
packagePair.InstalledPackage.FullPath
?? throw new InvalidOperationException("Package does not have a Path")
);
// Get user defined host and port // Get user defined host and port
var host = packagePair.InstalledPackage.GetLaunchArgsHost(); var host = packagePair.InstalledPackage.GetLaunchArgsHost();

194
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -3,11 +3,14 @@ using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics; using System.Diagnostics;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq; using System.Linq;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AsyncAwaitBestPractices; using AsyncAwaitBestPractices;
using Avalonia.Controls.Notifications;
using Avalonia.Media.Imaging;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using NLog; using NLog;
@ -27,6 +30,8 @@ using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Base; namespace StabilityMatrix.Avalonia.ViewModels.Base;
@ -41,6 +46,7 @@ public abstract partial class InferenceGenerationViewModelBase
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ISettingsManager settingsManager;
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private readonly ServiceManager<ViewModelBase> vmFactory; private readonly ServiceManager<ViewModelBase> vmFactory;
@ -60,11 +66,13 @@ public abstract partial class InferenceGenerationViewModelBase
protected InferenceGenerationViewModelBase( protected InferenceGenerationViewModelBase(
ServiceManager<ViewModelBase> vmFactory, ServiceManager<ViewModelBase> vmFactory,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
INotificationService notificationService INotificationService notificationService,
ISettingsManager settingsManager
) )
: base(notificationService) : base(notificationService)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.vmFactory = vmFactory; this.vmFactory = vmFactory;
ClientManager = inferenceClientManager; ClientManager = inferenceClientManager;
@ -75,6 +83,100 @@ public abstract partial class InferenceGenerationViewModelBase
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
} }
/// <summary>
/// Write an image to the default output folder
/// </summary>
protected Task<FilePath> WriteOutputImageAsync(
Stream imageStream,
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
)
{
var defaultOutputDir = settingsManager.ImagesInferenceDirectory;
defaultOutputDir.Create();
return WriteOutputImageAsync(
imageStream,
defaultOutputDir,
args,
batchNum,
batchTotal,
isGrid
);
}
/// <summary>
/// Write an image to an output folder
/// </summary>
protected async Task<FilePath> WriteOutputImageAsync(
Stream imageStream,
DirectoryPath outputDir,
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
)
{
var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat;
var formatProvider = new FileNameFormatProvider
{
GenerationParameters = args.Parameters,
ProjectType = args.Project?.ProjectType,
ProjectName = ProjectFile?.NameWithoutExtension
};
// Parse to format
if (
string.IsNullOrEmpty(formatTemplateStr)
|| !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format)
)
{
// Fallback to default
Logger.Warn(
"Failed to parse format template: {FormatTemplate}, using default",
formatTemplateStr
);
format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider);
}
if (isGrid)
{
format = format.WithGridPrefix();
}
if (batchNum >= 1 && batchTotal > 1)
{
format = format.WithBatchPostFix(batchNum, batchTotal);
}
var fileName = format.GetFileName() + ".png";
var file = outputDir.JoinFile(fileName);
// Until the file is free, keep adding _{i} to the end
for (var i = 0; i < 100; i++)
{
if (!file.Exists)
break;
file = outputDir.JoinFile($"{fileName}_{i + 1}");
}
// If that fails, append an 7-char uuid
if (file.Exists)
{
file = outputDir.JoinFile($"{fileName}_{Guid.NewGuid():N}"[..7]);
}
await using var fileStream = file.Info.OpenWrite();
await imageStream.CopyToAsync(fileStream);
return file;
}
/// <summary> /// <summary>
/// Builds the image generation prompt /// Builds the image generation prompt
/// </summary> /// </summary>
@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Wait for prompt to finish // Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken); await promptTask.Task.WaitAsync(cancellationToken);
Logger.Trace($"Prompt task {promptTask.Id} finished"); Logger.Debug($"Prompt task {promptTask.Id} finished");
// Get output images // Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync( var imageOutputs = await client.GetImagesForExecutedPromptAsync(
@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase
cancellationToken cancellationToken
); );
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
|| images is not { Count: > 0 }
)
{
// No images match
notificationService.Show(
"No output",
"Did not receive any output images",
NotificationType.Warning
);
return;
}
// Disable cancellation // Disable cancellation
await promptInterrupt.DisposeAsync(); await promptInterrupt.DisposeAsync();
@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGalleryCardViewModel.ImageSources.Clear(); ImageGalleryCardViewModel.ImageSources.Clear();
} }
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
)
{
// No images match
notificationService.Show("No output", "Did not receive any output images");
return;
}
await ProcessOutputImages(images, args); await ProcessOutputImages(images, args);
} }
finally finally
@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGenerationEventArgs args ImageGenerationEventArgs args
) )
{ {
var client = args.Client;
// Write metadata to images // Write metadata to images
var outputImagesBytes = new List<byte[]>();
var outputImages = new List<ImageSource>(); var outputImages = new List<ImageSource>();
foreach (
var (i, filePath) in images foreach (var (i, comfyImage) in images.Enumerate())
.Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
.Enumerate()
)
{ {
if (!filePath.Exists) Logger.Debug("Downloading image: {FileName}", comfyImage.FileName);
{ var imageStream = await client.GetImageStreamAsync(comfyImage);
Logger.Warn($"Image file {filePath} does not exist");
continue; using var ms = new MemoryStream();
} await imageStream.CopyToAsync(ms);
var imageArray = ms.ToArray();
outputImagesBytes.Add(imageArray);
var parameters = args.Parameters!; var parameters = args.Parameters!;
var project = args.Project!; var project = args.Project!;
@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase
); );
} }
var bytesWithMetadata = PngDataHelper.AddMetadata( var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project);
await filePath.ReadAllBytesAsync(),
parameters,
project
);
await using (var outputStream = filePath.Info.OpenWrite()) // Write using generated name
{ var filePath = await WriteOutputImageAsync(
await outputStream.WriteAsync(bytesWithMetadata); new MemoryStream(bytesWithMetadata),
await outputStream.FlushAsync(); args,
} i + 1,
images.Count
);
outputImages.Add(new ImageSource(filePath)); outputImages.Add(new ImageSource(filePath));
@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Download all images to make grid, if multiple // Download all images to make grid, if multiple
if (outputImages.Count > 1) if (outputImages.Count > 1)
{ {
var outputDir = outputImages[0].LocalFile!.Directory; var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray();
var loadedImages = outputImages
.Select(i => i.LocalFile)
.Where(f => f is { Exists: true })
.Select(f =>
{
using var stream = f!.Info.OpenRead();
return SKImage.FromEncodedData(stream);
})
.ToImmutableArray();
var project = args.Project!; var project = args.Project!;
@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase
); );
// Save to disk // Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name; var gridPath = await WriteOutputImageAsync(
var gridPath = outputDir!.JoinFile($"grid-{lastName}"); new MemoryStream(gridBytesWithMetadata),
args,
await using (var fileStream = gridPath.Info.OpenWrite()) isGrid: true
{ );
await fileStream.WriteAsync(gridBytesWithMetadata);
}
// Insert to start of images // Insert to start of images
var gridImage = new ImageSource(gridPath); var gridImage = new ImageSource(gridPath);

4
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs

@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
using Path = System.IO.Path; using Path = System.IO.Path;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration #pragma warning disable CS0657 // Not a valid attribute location for this declaration
@ -60,9 +61,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
public InferenceImageUpscaleViewModel( public InferenceImageUpscaleViewModel(
INotificationService notificationService, INotificationService notificationService,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> vmFactory ServiceManager<ViewModelBase> vmFactory
) )
: base(vmFactory, inferenceClientManager, notificationService) : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;

3
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -86,10 +86,11 @@ public class InferenceTextToImageViewModel
public InferenceTextToImageViewModel( public InferenceTextToImageViewModel(
INotificationService notificationService, INotificationService notificationService,
IInferenceClientManager inferenceClientManager, IInferenceClientManager inferenceClientManager,
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> vmFactory, ServiceManager<ViewModelBase> vmFactory,
IModelIndexService modelIndexService IModelIndexService modelIndexService
) )
: base(vmFactory, inferenceClientManager, notificationService) : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;
this.modelIndexService = modelIndexService; this.modelIndexService = modelIndexService;

Loading…
Cancel
Save