|
|
|
@ -3,11 +3,14 @@ using System.Collections.Generic;
|
|
|
|
|
using System.Collections.Immutable; |
|
|
|
|
using System.Diagnostics; |
|
|
|
|
using System.Diagnostics.CodeAnalysis; |
|
|
|
|
using System.IO; |
|
|
|
|
using System.Linq; |
|
|
|
|
using System.Text.Json.Serialization; |
|
|
|
|
using System.Threading; |
|
|
|
|
using System.Threading.Tasks; |
|
|
|
|
using AsyncAwaitBestPractices; |
|
|
|
|
using Avalonia.Controls.Notifications; |
|
|
|
|
using Avalonia.Media.Imaging; |
|
|
|
|
using Avalonia.Threading; |
|
|
|
|
using CommunityToolkit.Mvvm.Input; |
|
|
|
|
using NLog; |
|
|
|
@ -27,6 +30,8 @@ using StabilityMatrix.Core.Models;
|
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy; |
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; |
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; |
|
|
|
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
|
|
|
using StabilityMatrix.Core.Services; |
|
|
|
|
|
|
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Base; |
|
|
|
|
|
|
|
|
@ -41,6 +46,7 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
{ |
|
|
|
|
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); |
|
|
|
|
|
|
|
|
|
private readonly ISettingsManager settingsManager; |
|
|
|
|
private readonly INotificationService notificationService; |
|
|
|
|
private readonly ServiceManager<ViewModelBase> vmFactory; |
|
|
|
|
|
|
|
|
@ -60,11 +66,13 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
protected InferenceGenerationViewModelBase( |
|
|
|
|
ServiceManager<ViewModelBase> vmFactory, |
|
|
|
|
IInferenceClientManager inferenceClientManager, |
|
|
|
|
INotificationService notificationService |
|
|
|
|
INotificationService notificationService, |
|
|
|
|
ISettingsManager settingsManager |
|
|
|
|
) |
|
|
|
|
: base(notificationService) |
|
|
|
|
{ |
|
|
|
|
this.notificationService = notificationService; |
|
|
|
|
this.settingsManager = settingsManager; |
|
|
|
|
this.vmFactory = vmFactory; |
|
|
|
|
|
|
|
|
|
ClientManager = inferenceClientManager; |
|
|
|
@ -75,6 +83,100 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
/// Write an image to the default output folder |
|
|
|
|
/// </summary> |
|
|
|
|
protected Task<FilePath> WriteOutputImageAsync( |
|
|
|
|
Stream imageStream, |
|
|
|
|
ImageGenerationEventArgs args, |
|
|
|
|
int batchNum = 0, |
|
|
|
|
int batchTotal = 0, |
|
|
|
|
bool isGrid = false |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
var defaultOutputDir = settingsManager.ImagesInferenceDirectory; |
|
|
|
|
defaultOutputDir.Create(); |
|
|
|
|
|
|
|
|
|
return WriteOutputImageAsync( |
|
|
|
|
imageStream, |
|
|
|
|
defaultOutputDir, |
|
|
|
|
args, |
|
|
|
|
batchNum, |
|
|
|
|
batchTotal, |
|
|
|
|
isGrid |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
/// Write an image to an output folder |
|
|
|
|
/// </summary> |
|
|
|
|
protected async Task<FilePath> WriteOutputImageAsync( |
|
|
|
|
Stream imageStream, |
|
|
|
|
DirectoryPath outputDir, |
|
|
|
|
ImageGenerationEventArgs args, |
|
|
|
|
int batchNum = 0, |
|
|
|
|
int batchTotal = 0, |
|
|
|
|
bool isGrid = false |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat; |
|
|
|
|
|
|
|
|
|
var formatProvider = new FileNameFormatProvider |
|
|
|
|
{ |
|
|
|
|
GenerationParameters = args.Parameters, |
|
|
|
|
ProjectType = args.Project?.ProjectType, |
|
|
|
|
ProjectName = ProjectFile?.NameWithoutExtension |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
// Parse to format |
|
|
|
|
if ( |
|
|
|
|
string.IsNullOrEmpty(formatTemplateStr) |
|
|
|
|
|| !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format) |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
// Fallback to default |
|
|
|
|
Logger.Warn( |
|
|
|
|
"Failed to parse format template: {FormatTemplate}, using default", |
|
|
|
|
formatTemplateStr |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (isGrid) |
|
|
|
|
{ |
|
|
|
|
format = format.WithGridPrefix(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (batchNum >= 1 && batchTotal > 1) |
|
|
|
|
{ |
|
|
|
|
format = format.WithBatchPostFix(batchNum, batchTotal); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var fileName = format.GetFileName() + ".png"; |
|
|
|
|
var file = outputDir.JoinFile(fileName); |
|
|
|
|
|
|
|
|
|
// Until the file is free, keep adding _{i} to the end |
|
|
|
|
for (var i = 0; i < 100; i++) |
|
|
|
|
{ |
|
|
|
|
if (!file.Exists) |
|
|
|
|
break; |
|
|
|
|
|
|
|
|
|
file = outputDir.JoinFile($"{fileName}_{i + 1}"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If that fails, append an 7-char uuid |
|
|
|
|
if (file.Exists) |
|
|
|
|
{ |
|
|
|
|
file = outputDir.JoinFile($"{fileName}_{Guid.NewGuid():N}"[..7]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
await using var fileStream = file.Info.OpenWrite(); |
|
|
|
|
await imageStream.CopyToAsync(fileStream); |
|
|
|
|
|
|
|
|
|
return file; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
/// Builds the image generation prompt |
|
|
|
|
/// </summary> |
|
|
|
@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
|
|
|
|
|
// Wait for prompt to finish |
|
|
|
|
await promptTask.Task.WaitAsync(cancellationToken); |
|
|
|
|
Logger.Trace($"Prompt task {promptTask.Id} finished"); |
|
|
|
|
Logger.Debug($"Prompt task {promptTask.Id} finished"); |
|
|
|
|
|
|
|
|
|
// Get output images |
|
|
|
|
var imageOutputs = await client.GetImagesForExecutedPromptAsync( |
|
|
|
@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
cancellationToken |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
if ( |
|
|
|
|
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) |
|
|
|
|
|| images is not { Count: > 0 } |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
// No images match |
|
|
|
|
notificationService.Show( |
|
|
|
|
"No output", |
|
|
|
|
"Did not receive any output images", |
|
|
|
|
NotificationType.Warning |
|
|
|
|
); |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Disable cancellation |
|
|
|
|
await promptInterrupt.DisposeAsync(); |
|
|
|
|
|
|
|
|
@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
ImageGalleryCardViewModel.ImageSources.Clear(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if ( |
|
|
|
|
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
// No images match |
|
|
|
|
notificationService.Show("No output", "Did not receive any output images"); |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
await ProcessOutputImages(images, args); |
|
|
|
|
} |
|
|
|
|
finally |
|
|
|
@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
ImageGenerationEventArgs args |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
var client = args.Client; |
|
|
|
|
|
|
|
|
|
// Write metadata to images |
|
|
|
|
var outputImagesBytes = new List<byte[]>(); |
|
|
|
|
var outputImages = new List<ImageSource>(); |
|
|
|
|
foreach ( |
|
|
|
|
var (i, filePath) in images |
|
|
|
|
.Select(image => image.ToFilePath(args.Client.OutputImagesDir!)) |
|
|
|
|
.Enumerate() |
|
|
|
|
) |
|
|
|
|
{ |
|
|
|
|
if (!filePath.Exists) |
|
|
|
|
|
|
|
|
|
foreach (var (i, comfyImage) in images.Enumerate()) |
|
|
|
|
{ |
|
|
|
|
Logger.Warn($"Image file {filePath} does not exist"); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
Logger.Debug("Downloading image: {FileName}", comfyImage.FileName); |
|
|
|
|
var imageStream = await client.GetImageStreamAsync(comfyImage); |
|
|
|
|
|
|
|
|
|
using var ms = new MemoryStream(); |
|
|
|
|
await imageStream.CopyToAsync(ms); |
|
|
|
|
|
|
|
|
|
var imageArray = ms.ToArray(); |
|
|
|
|
outputImagesBytes.Add(imageArray); |
|
|
|
|
|
|
|
|
|
var parameters = args.Parameters!; |
|
|
|
|
var project = args.Project!; |
|
|
|
@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var bytesWithMetadata = PngDataHelper.AddMetadata( |
|
|
|
|
await filePath.ReadAllBytesAsync(), |
|
|
|
|
parameters, |
|
|
|
|
project |
|
|
|
|
); |
|
|
|
|
var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project); |
|
|
|
|
|
|
|
|
|
await using (var outputStream = filePath.Info.OpenWrite()) |
|
|
|
|
{ |
|
|
|
|
await outputStream.WriteAsync(bytesWithMetadata); |
|
|
|
|
await outputStream.FlushAsync(); |
|
|
|
|
} |
|
|
|
|
// Write using generated name |
|
|
|
|
var filePath = await WriteOutputImageAsync( |
|
|
|
|
new MemoryStream(bytesWithMetadata), |
|
|
|
|
args, |
|
|
|
|
i + 1, |
|
|
|
|
images.Count |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
outputImages.Add(new ImageSource(filePath)); |
|
|
|
|
|
|
|
|
@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
// Download all images to make grid, if multiple |
|
|
|
|
if (outputImages.Count > 1) |
|
|
|
|
{ |
|
|
|
|
var outputDir = outputImages[0].LocalFile!.Directory; |
|
|
|
|
|
|
|
|
|
var loadedImages = outputImages |
|
|
|
|
.Select(i => i.LocalFile) |
|
|
|
|
.Where(f => f is { Exists: true }) |
|
|
|
|
.Select(f => |
|
|
|
|
{ |
|
|
|
|
using var stream = f!.Info.OpenRead(); |
|
|
|
|
return SKImage.FromEncodedData(stream); |
|
|
|
|
}) |
|
|
|
|
.ToImmutableArray(); |
|
|
|
|
var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray(); |
|
|
|
|
|
|
|
|
|
var project = args.Project!; |
|
|
|
|
|
|
|
|
@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase
|
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
// Save to disk |
|
|
|
|
var lastName = outputImages.Last().LocalFile?.Info.Name; |
|
|
|
|
var gridPath = outputDir!.JoinFile($"grid-{lastName}"); |
|
|
|
|
|
|
|
|
|
await using (var fileStream = gridPath.Info.OpenWrite()) |
|
|
|
|
{ |
|
|
|
|
await fileStream.WriteAsync(gridBytesWithMetadata); |
|
|
|
|
} |
|
|
|
|
var gridPath = await WriteOutputImageAsync( |
|
|
|
|
new MemoryStream(gridBytesWithMetadata), |
|
|
|
|
args, |
|
|
|
|
isGrid: true |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
// Insert to start of images |
|
|
|
|
var gridImage = new ImageSource(gridPath); |
|
|
|
|