Browse Source

added webp to gif conversion & display in img2vid

pull/438/head
JT 11 months ago
parent
commit
ff4c347567
  1. 10
      StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs
  2. 1
      StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
  3. 149
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
  4. 54
      StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs
  5. 2
      StabilityMatrix.Avalonia/ViewModels/ConsoleViewModel.cs
  6. 23
      StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs
  7. 50
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs
  8. 14
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
  9. 39
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  10. 20
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  11. 20
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs
  12. 86
      StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs
  13. 5
      StabilityMatrix.Avalonia/Views/Inference/InferenceImageToVideoView.axaml
  14. 4
      StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs
  15. 14
      StabilityMatrix.Core/Animation/GifConverter.cs
  16. 1
      StabilityMatrix.Core/StabilityMatrix.Core.csproj
  17. 6
      StabilityMatrix.sln

10
StabilityMatrix.Avalonia/Models/Inference/ModuleApplyStepEventArgs.cs

@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.Models.Inference;
@ -29,18 +29,14 @@ public class ModuleApplyStepEventArgs : EventArgs
/// <summary>
/// Generation overrides (like hires fix generate, current seed generate, etc.)
/// </summary>
public IReadOnlyDictionary<Type, bool> IsEnabledOverrides { get; init; } =
new Dictionary<Type, bool>();
public IReadOnlyDictionary<Type, bool> IsEnabledOverrides { get; init; } = new Dictionary<Type, bool>();
public class ModuleApplyStepTemporaryArgs
{
/// <summary>
/// Temporary conditioning apply step, used by samplers to apply control net.
/// </summary>
public (
ConditioningNodeConnection Positive,
ConditioningNodeConnection Negative
)? Conditioning { get; set; }
public (ConditioningNodeConnection Positive, ConditioningNodeConnection Negative)? Conditioning { get; set; }
/// <summary>
/// Temporary refiner conditioning apply step, used by samplers to apply control net.

1
StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

@ -100,6 +100,7 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Avalonia.Gif\Avalonia.Gif.csproj" />
<ProjectReference Include="..\StabilityMatrix.Core\StabilityMatrix.Core.csproj" />
</ItemGroup>

149
StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs

@ -23,6 +23,7 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Animation;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
@ -41,9 +42,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
/// This includes a progress reporter, image output view model, and generation virtual methods.
/// </summary>
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract partial class InferenceGenerationViewModelBase
: InferenceTabViewModelBase,
IImageGalleryComponent
public abstract partial class InferenceGenerationViewModelBase : InferenceTabViewModelBase, IImageGalleryComponent
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -92,20 +91,14 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
bool isGrid = false,
string fileExtension = "png"
)
{
var defaultOutputDir = settingsManager.ImagesInferenceDirectory;
defaultOutputDir.Create();
return WriteOutputImageAsync(
imageStream,
defaultOutputDir,
args,
batchNum,
batchTotal,
isGrid
);
return WriteOutputImageAsync(imageStream, defaultOutputDir, args, batchNum, batchTotal, isGrid, fileExtension);
}
/// <summary>
@ -117,7 +110,8 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGenerationEventArgs args,
int batchNum = 0,
int batchTotal = 0,
bool isGrid = false
bool isGrid = false,
string fileExtension = "png"
)
{
var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat;
@ -136,10 +130,7 @@ public abstract partial class InferenceGenerationViewModelBase
)
{
// Fallback to default
Logger.Warn(
"Failed to parse format template: {FormatTemplate}, using default",
formatTemplateStr
);
Logger.Warn("Failed to parse format template: {FormatTemplate}, using default", formatTemplateStr);
format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider);
}
@ -155,7 +146,7 @@ public abstract partial class InferenceGenerationViewModelBase
}
var fileName = format.GetFileName();
var file = outputDir.JoinFile($"{fileName}.png");
var file = outputDir.JoinFile($"{fileName}.{fileExtension}");
// Until the file is free, keep adding _{i} to the end
for (var i = 0; i < 100; i++)
@ -163,14 +154,14 @@ public abstract partial class InferenceGenerationViewModelBase
if (!file.Exists)
break;
file = outputDir.JoinFile($"{fileName}_{i + 1}.png");
file = outputDir.JoinFile($"{fileName}_{i + 1}.{fileExtension}");
}
// If that fails, append an 7-char uuid
if (file.Exists)
{
var uuid = Guid.NewGuid().ToString("N")[..7];
file = outputDir.JoinFile($"{fileName}_{uuid}.png");
file = outputDir.JoinFile($"{fileName}_{uuid}.{fileExtension}");
}
await using var fileStream = file.Info.OpenWrite();
@ -200,11 +191,7 @@ public abstract partial class InferenceGenerationViewModelBase
{
var uploadName = await image.GetHashGuidFileNameAsync();
Logger.Debug(
"Uploading image {FileName} as {UploadName}",
localFile.Name,
uploadName
);
Logger.Debug("Uploading image {FileName} as {UploadName}", localFile.Name, uploadName);
// For pngs, strip metadata since Pillow can't handle some valid files?
if (localFile.Info.Extension.Equals(".png", StringComparison.OrdinalIgnoreCase))
@ -228,10 +215,7 @@ public abstract partial class InferenceGenerationViewModelBase
/// Runs a generation task
/// </summary>
/// <exception cref="InvalidOperationException">Thrown if args.Parameters or args.Project are null</exception>
protected async Task RunGeneration(
ImageGenerationEventArgs args,
CancellationToken cancellationToken
)
protected async Task RunGeneration(ImageGenerationEventArgs args, CancellationToken cancellationToken)
{
var client = args.Client;
var nodes = args.Nodes;
@ -311,32 +295,18 @@ public abstract partial class InferenceGenerationViewModelBase
{
Logger.Warn(e, "Comfy node exception while queuing prompt");
await DialogHelper
.CreateJsonDialog(
e.JsonData,
"Comfy Error",
"Node execution encountered an error"
)
.CreateJsonDialog(e.JsonData, "Comfy Error", "Node execution encountered an error")
.ShowAsync();
return;
}
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
promptTask.Id,
cancellationToken
);
var imageOutputs = await client.GetImagesForExecutedPromptAsync(promptTask.Id, cancellationToken);
if (
!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
|| images is not { Count: > 0 }
)
if (!imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is not { Count: > 0 })
{
// No images match
notificationService.Show(
"No output",
"Did not receive any output images",
NotificationType.Warning
);
notificationService.Show("No output", "Did not receive any output images", NotificationType.Warning);
return;
}
@ -369,10 +339,7 @@ public abstract partial class InferenceGenerationViewModelBase
/// <summary>
/// Handles image output metadata for generation runs
/// </summary>
private async Task ProcessOutputImages(
IReadOnlyCollection<ComfyImage> images,
ImageGenerationEventArgs args
)
private async Task ProcessOutputImages(IReadOnlyCollection<ComfyImage> images, ImageGenerationEventArgs args)
{
var client = args.Client;
@ -395,10 +362,7 @@ public abstract partial class InferenceGenerationViewModelBase
var project = args.Project!;
// Lock seed
project.TryUpdateModel<SeedCardModel>(
"Seed",
model => model with { IsRandomizeEnabled = false }
);
project.TryUpdateModel<SeedCardModel>("Seed", model => model with { IsRandomizeEnabled = false });
// Seed and batch override for batches
if (images.Count > 1 && project.ProjectType is InferenceProjectType.TextToImage)
@ -433,6 +397,29 @@ public abstract partial class InferenceGenerationViewModelBase
outputImages.Add(new ImageSource(filePath));
EventManager.Instance.OnImageFileAdded(filePath);
}
else if (comfyImage.FileName.EndsWith(".webp"))
{
// Write using generated name
var webpFilePath = await WriteOutputImageAsync(
new MemoryStream(imageArray),
args,
i + 1,
images.Count,
fileExtension: Path.GetExtension(comfyImage.FileName).Replace(".", "")
);
// convert to gif
await GifConverter.ConvertWebpToGif(webpFilePath);
var gifFilePath = webpFilePath.ToString().Replace(".webp", ".gif");
if (File.Exists(gifFilePath))
{
// delete webp
File.Delete(webpFilePath);
}
outputImages.Add(new ImageSource(gifFilePath));
EventManager.Instance.OnImageFileAdded(gifFilePath);
}
else
{
// Write using generated name
@ -440,7 +427,8 @@ public abstract partial class InferenceGenerationViewModelBase
new MemoryStream(imageArray),
args,
i + 1,
images.Count
images.Count,
fileExtension: Path.GetExtension(comfyImage.FileName).Replace(".", "")
);
outputImages.Add(new ImageSource(filePath));
@ -456,25 +444,14 @@ public abstract partial class InferenceGenerationViewModelBase
var project = args.Project!;
// Lock seed
project.TryUpdateModel<SeedCardModel>(
"Seed",
model => model with { IsRandomizeEnabled = false }
);
project.TryUpdateModel<SeedCardModel>("Seed", model => model with { IsRandomizeEnabled = false });
var grid = ImageProcessor.CreateImageGrid(loadedImages);
var gridBytes = grid.Encode().ToArray();
var gridBytesWithMetadata = PngDataHelper.AddMetadata(
gridBytes,
args.Parameters!,
args.Project!
);
var gridBytesWithMetadata = PngDataHelper.AddMetadata(gridBytes, args.Parameters!, args.Project!);
// Save to disk
var gridPath = await WriteOutputImageAsync(
new MemoryStream(gridBytesWithMetadata),
args,
isGrid: true
);
var gridPath = await WriteOutputImageAsync(new MemoryStream(gridBytesWithMetadata), args, isGrid: true);
// Insert to start of images
var gridImage = new ImageSource(gridPath);
@ -497,10 +474,7 @@ public abstract partial class InferenceGenerationViewModelBase
/// <summary>
/// Implementation for Generate Image
/// </summary>
protected virtual Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
protected virtual Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
@ -511,10 +485,7 @@ public abstract partial class InferenceGenerationViewModelBase
/// <param name="options">Optional overrides (side buttons)</param>
/// <param name="cancellationToken">Cancellation token</param>
[RelayCommand(IncludeCancelCommand = true, FlowExceptionsToTaskScheduler = true)]
private async Task GenerateImage(
GenerateFlags options = default,
CancellationToken cancellationToken = default
)
private async Task GenerateImage(GenerateFlags options = default, CancellationToken cancellationToken = default)
{
var overrides = GenerateOverrides.FromFlags(options);
@ -555,20 +526,18 @@ public abstract partial class InferenceGenerationViewModelBase
/// Handles the progress update received event from the websocket.
/// Updates the progress view model.
/// </summary>
protected virtual void OnProgressUpdateReceived(
object? sender,
ComfyProgressUpdateEventArgs args
)
protected virtual void OnProgressUpdateReceived(object? sender, ComfyProgressUpdateEventArgs args)
{
Dispatcher.UIThread.Post(() =>
Dispatcher
.UIThread
.Post(() =>
{
OutputProgress.Value = args.Value;
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text =
$"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
$"({args.Value} / {args.Maximum})" + (args.RunningNode != null ? $" {args.RunningNode}" : "");
});
}
@ -594,7 +563,9 @@ public abstract partial class InferenceGenerationViewModelBase
return;
}
Dispatcher.UIThread.Post(() =>
Dispatcher
.UIThread
.Post(() =>
{
OutputProgress.IsIndeterminate = true;
OutputProgress.Value = 100;
@ -628,11 +599,7 @@ public abstract partial class InferenceGenerationViewModelBase
overrides[typeof(HiresFixModule)] = args.Overrides.IsHiresFixEnabled.Value;
}
return new ModuleApplyStepEventArgs
{
Builder = args.Builder,
IsEnabledOverrides = overrides
};
return new ModuleApplyStepEventArgs { Builder = args.Builder, IsEnabledOverrides = overrides };
}
}
}

54
StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs

@ -28,24 +28,16 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private static readonly Type[] SerializerIgnoredTypes =
{
typeof(ICommand),
typeof(IRelayCommand)
};
private static readonly Type[] SerializerIgnoredTypes = { typeof(ICommand), typeof(IRelayCommand) };
private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) };
private static readonly JsonSerializerOptions SerializerOptions =
new() { IgnoreReadOnlyProperties = true };
private static readonly JsonSerializerOptions SerializerOptions = new() { IgnoreReadOnlyProperties = true };
private static bool ShouldIgnoreProperty(PropertyInfo property)
{
// Skip if read-only and not IJsonLoadableState
if (
property.SetMethod is null
&& !typeof(IJsonLoadableState).IsAssignableFrom(property.PropertyType)
)
if (property.SetMethod is null && !typeof(IJsonLoadableState).IsAssignableFrom(property.PropertyType))
{
Logger.ConditionalTrace("Skipping {Property} - read-only", property.Name);
return true;
@ -107,11 +99,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{
// Get all of our properties using reflection
var properties = GetType().GetProperties();
Logger.ConditionalTrace(
"Serializing {Type} with {Count} properties",
GetType(),
properties.Length
);
Logger.ConditionalTrace("Serializing {Type} with {Count} properties", GetType(), properties.Length);
foreach (var property in properties)
{
@ -119,9 +107,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
// If JsonPropertyName provided, use that as the key
if (
property
.GetCustomAttributes(typeof(JsonPropertyNameAttribute), true)
.FirstOrDefault()
property.GetCustomAttributes(typeof(JsonPropertyNameAttribute), true).FirstOrDefault()
is JsonPropertyNameAttribute jsonPropertyName
)
{
@ -168,10 +154,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
if (property.GetValue(this) is not IJsonLoadableState propertyValue)
{
// If null, it must have a default constructor
if (
property.PropertyType.GetConstructor(Type.EmptyTypes)
is not { } constructorInfo
)
if (property.PropertyType.GetConstructor(Type.EmptyTypes) is not { } constructorInfo)
{
throw new InvalidOperationException(
$"Property {property.Name} is IJsonLoadableState but current object is null and has no default constructor"
@ -188,11 +171,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
}
else
{
Logger.ConditionalTrace(
"Loading {Property} ({Type})",
property.Name,
property.PropertyType
);
Logger.ConditionalTrace("Loading {Property} ({Type})", property.Name, property.PropertyType);
var propertyValue = value.Deserialize(property.PropertyType, SerializerOptions);
property.SetValue(this, propertyValue);
@ -216,11 +195,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{
// Get all of our properties using reflection.
var properties = GetType().GetProperties();
Logger.ConditionalTrace(
"Serializing {Type} with {Count} properties",
GetType(),
properties.Length
);
Logger.ConditionalTrace("Serializing {Type} with {Count} properties", GetType(), properties.Length);
// Create a JSON object to store the state.
var state = new JsonObject();
@ -237,9 +212,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
// If JsonPropertyName provided, use that as the key.
if (
property
.GetCustomAttributes(typeof(JsonPropertyNameAttribute), true)
.FirstOrDefault()
property.GetCustomAttributes(typeof(JsonPropertyNameAttribute), true).FirstOrDefault()
is JsonPropertyNameAttribute jsonPropertyName
)
{
@ -270,11 +243,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
}
else
{
Logger.ConditionalTrace(
"Serializing {Property} ({Type})",
property.Name,
property.PropertyType
);
Logger.ConditionalTrace("Serializing {Property} ({Type})", property.Name, property.PropertyType);
var value = property.GetValue(this);
if (value is not null)
{
@ -297,8 +266,7 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
protected static JsonObject SerializeModel<T>(T model)
{
var node = JsonSerializer.SerializeToNode(model);
return node?.AsObject()
?? throw new NullReferenceException("Failed to serialize state to JSON object.");
return node?.AsObject() ?? throw new NullReferenceException("Failed to serialize state to JSON object.");
}
/// <summary>

2
StabilityMatrix.Avalonia/ViewModels/ConsoleViewModel.cs

@ -7,9 +7,9 @@ using System.Threading.Tasks.Dataflow;
using Avalonia.Threading;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using NLog;
using Nito.AsyncEx;
using Nito.AsyncEx.Synchronous;
using NLog;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Processes;

23
StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs

@ -126,10 +126,7 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
ShowProgressBar = true;
IsProgressIndeterminate = true;
UpdateText = string.Format(
Resources.TextTemplate_UpdatingPackage,
Resources.Label_StabilityMatrix
);
UpdateText = string.Format(Resources.TextTemplate_UpdatingPackage, Resources.Label_StabilityMatrix);
try
{
@ -173,10 +170,7 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
UpdateText = "Getting a few things ready...";
await using (new MinimumDelay(500, 1000))
{
Process.Start(
UpdateHelper.ExecutablePath,
$"--wait-for-exit-pid {Environment.ProcessId}"
);
Process.Start(UpdateHelper.ExecutablePath, $"--wait-for-exit-pid {Environment.ProcessId}");
}
UpdateText = "Update complete. Restarting Stability Matrix in 3 seconds...";
@ -262,15 +256,10 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
// Join all blocks until and excluding the current version
// If we're on a pre-release, include the current release
var currentVersionBlock = results.FindIndex(
x => x.Version == currentVersion.WithoutMetadata()
);
var currentVersionBlock = results.FindIndex(x => x.Version == currentVersion.WithoutMetadata());
// For mismatching build metadata, add one
if (
currentVersionBlock != -1
&& results[currentVersionBlock].Version?.Metadata != currentVersion.Metadata
)
if (currentVersionBlock != -1 && results[currentVersionBlock].Version?.Metadata != currentVersion.Metadata)
{
currentVersionBlock++;
}
@ -278,9 +267,7 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
// Support for previous pre-release without changelogs
if (currentVersionBlock == -1)
{
currentVersionBlock = results.FindIndex(
x => x.Version == currentVersion.WithoutPrereleaseOrMetadata()
);
currentVersionBlock = results.FindIndex(x => x.Version == currentVersion.WithoutPrereleaseOrMetadata());
// Add 1 if found to include the current release
if (currentVersionBlock != -1)

50
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToVideoViewModel.cs

@ -7,18 +7,20 @@ using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Avalonia.ViewModels.Inference.Video;
using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
@ -28,9 +30,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceImageToVideoView), persistent: true)]
[ManagedService]
[Transient]
public class InferenceImageToVideoViewModel
: InferenceGenerationViewModelBase,
IParametersLoadableState
public partial class InferenceImageToVideoViewModel : InferenceGenerationViewModelBase, IParametersLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -61,6 +61,10 @@ public class InferenceImageToVideoViewModel
[JsonPropertyName("VideoOutput")]
public VideoOutputSettingsCardViewModel VideoOutputSettingsCardViewModel { get; }
[ObservableProperty]
[JsonIgnore]
private string outputUri;
public InferenceImageToVideoViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
@ -86,6 +90,10 @@ public class InferenceImageToVideoViewModel
samplerCard.IsCfgScaleEnabled = true;
samplerCard.IsSamplerSelectionEnabled = true;
samplerCard.IsSchedulerSelectionEnabled = true;
samplerCard.CfgScale = 2.5d;
samplerCard.SelectedSampler = ComfySampler.Euler;
samplerCard.SelectedScheduler = ComfyScheduler.Karras;
samplerCard.IsDenoiseStrengthEnabled = true;
});
BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>();
@ -105,6 +113,19 @@ public class InferenceImageToVideoViewModel
);
}
public override void OnLoaded()
{
EventManager.Instance.ImageFileAdded += OnImageFileAdded;
}
private void OnImageFileAdded(object? sender, FilePath e)
{
if (!e.Extension.Contains("gif"))
return;
OutputUri = e;
}
/// <inheritdoc />
protected override void BuildPrompt(BuildPromptEventArgs args)
{
@ -122,7 +143,9 @@ public class InferenceImageToVideoViewModel
ModelCardViewModel.ApplyStep(args);
// Setup latent from image
var imageLoad = builder.Nodes.AddTypedNode(
var imageLoad = builder
.Nodes
.AddTypedNode(
new ComfyNodeBuilder.LoadImage
{
Name = builder.Nodes.GetUniqueName("ControlNet_LoadImage"),
@ -132,9 +155,7 @@ public class InferenceImageToVideoViewModel
}
);
builder.Connections.Primary = imageLoad.Output1;
builder.Connections.PrimarySize =
SelectImageCardViewModel.CurrentBitmapSize
?? new Size(SamplerCardViewModel.Width, SamplerCardViewModel.Height);
builder.Connections.PrimarySize = SelectImageCardViewModel.CurrentBitmapSize;
// Setup img2vid stuff
// Set width & height from SamplerCard
@ -159,10 +180,7 @@ public class InferenceImageToVideoViewModel
}
/// <inheritdoc />
protected override async Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
protected override async Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken)
{
if (!await CheckClientConnectedWithPrompt() || !ClientManager.IsConnected)
{
@ -184,11 +202,7 @@ public class InferenceImageToVideoViewModel
{
var seed = seedCard.Seed + i;
var buildPromptArgs = new BuildPromptEventArgs
{
Overrides = overrides,
SeedOverride = seed
};
var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed };
BuildPrompt(buildPromptArgs);
var generationArgs = new ImageGenerationEventArgs

14
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs

@ -107,9 +107,7 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
// If upscale is enabled, add another upscale group
if (IsUpscaleEnabled)
{
var upscaleSize = builder.Connections.PrimarySize.WithScale(
UpscalerCardViewModel.Scale
);
var upscaleSize = builder.Connections.PrimarySize.WithScale(UpscalerCardViewModel.Scale);
// Build group
builder.Connections.Primary = builder
@ -144,10 +142,7 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
}
/// <inheritdoc />
protected override async Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
protected override async Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken)
{
if (!ClientManager.IsConnected)
{
@ -174,10 +169,7 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
Client = ClientManager.Client,
Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
Parameters = new GenerationParameters
{
ModelName = UpscalerCardViewModel.SelectedUpscaler?.Name,
},
Parameters = new GenerationParameters { ModelName = UpscalerCardViewModel.SelectedUpscaler?.Name, },
Project = InferenceProjectDocument.FromLoadable(this)
};

39
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -26,9 +26,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), IsPersistent = true)]
[ManagedService]
[Transient]
public class InferenceTextToImageViewModel
: InferenceGenerationViewModelBase,
IParametersLoadableState
public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, IParametersLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -162,22 +160,19 @@ public class InferenceTextToImageViewModel
/// <inheritdoc />
protected override IEnumerable<ImageSource> GetInputImages()
{
var samplerImages = SamplerCardViewModel.ModulesCardViewModel.Cards
var samplerImages = SamplerCardViewModel
.ModulesCardViewModel
.Cards
.OfType<IInputImageProvider>()
.SelectMany(m => m.GetInputImages());
var moduleImages = ModulesCardViewModel.Cards
.OfType<IInputImageProvider>()
.SelectMany(m => m.GetInputImages());
var moduleImages = ModulesCardViewModel.Cards.OfType<IInputImageProvider>().SelectMany(m => m.GetInputImages());
return samplerImages.Concat(moduleImages);
}
/// <inheritdoc />
protected override async Task GenerateImageImpl(
GenerateOverrides overrides,
CancellationToken cancellationToken
)
protected override async Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken)
{
// Validate the prompts
if (!await PromptCardViewModel.ValidatePrompts())
@ -205,11 +200,7 @@ public class InferenceTextToImageViewModel
{
var seed = seedCard.Seed + i;
var buildPromptArgs = new BuildPromptEventArgs
{
Overrides = overrides,
SeedOverride = seed
};
var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed };
BuildPrompt(buildPromptArgs);
var generationArgs = new ImageGenerationEventArgs
@ -270,16 +261,12 @@ public class InferenceTextToImageViewModel
if (state.TryGetPropertyValue("HiresSampler", out var hiresSamplerState))
{
module
.GetCard<SamplerCardViewModel>()
.LoadStateFromJsonObject(hiresSamplerState!.AsObject());
module.GetCard<SamplerCardViewModel>().LoadStateFromJsonObject(hiresSamplerState!.AsObject());
}
if (state.TryGetPropertyValue("HiresUpscaler", out var hiresUpscalerState))
{
module
.GetCard<UpscalerCardViewModel>()
.LoadStateFromJsonObject(hiresUpscalerState!.AsObject());
module.GetCard<UpscalerCardViewModel>().LoadStateFromJsonObject(hiresUpscalerState!.AsObject());
}
});
@ -289,14 +276,14 @@ public class InferenceTextToImageViewModel
if (state.TryGetPropertyValue("Upscaler", out var upscalerState))
{
module
.GetCard<UpscalerCardViewModel>()
.LoadStateFromJsonObject(upscalerState!.AsObject());
module.GetCard<UpscalerCardViewModel>().LoadStateFromJsonObject(upscalerState!.AsObject());
}
});
// Add FreeU to sampler
SamplerCardViewModel.ModulesCardViewModel.AddModule<FreeUModule>(module =>
SamplerCardViewModel
.ModulesCardViewModel
.AddModule<FreeUModule>(module =>
{
module.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsFreeUEnabled");
});

20
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -10,8 +10,8 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -42,6 +42,7 @@ public partial class ModelCardViewModel(IInferenceClientManager clientManager)
private bool disableSettings;
public IInferenceClientManager ClientManager { get; } = clientManager;
/// <inheritdoc />
public virtual void ApplyStep(ModuleApplyStepEventArgs e)
{
@ -50,9 +51,7 @@ public partial class ModelCardViewModel(IInferenceClientManager clientManager)
new ComfyNodeBuilder.CheckpointLoaderSimple
{
Name = "CheckpointLoader",
CkptName =
SelectedModel?.RelativePath
?? throw new ValidationException("Model not selected")
CkptName = SelectedModel?.RelativePath ?? throw new ValidationException("Model not selected")
}
);
@ -85,9 +84,7 @@ public partial class ModelCardViewModel(IInferenceClientManager clientManager)
new ComfyNodeBuilder.VAELoader
{
Name = "VAELoader",
VaeName =
SelectedVae?.RelativePath
?? throw new ValidationException("VAE enabled but not selected")
VaeName = SelectedVae?.RelativePath ?? throw new ValidationException("VAE enabled but not selected")
}
);
@ -147,19 +144,14 @@ public partial class ModelCardViewModel(IInferenceClientManager clientManager)
model = currentModels.FirstOrDefault(
m =>
m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
&& sha256.StartsWith(
parameters.ModelHash,
StringComparison.InvariantCultureIgnoreCase
)
&& sha256.StartsWith(parameters.ModelHash, StringComparison.InvariantCultureIgnoreCase)
);
}
else
{
// Name matches
model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName));
model ??= currentModels.FirstOrDefault(
m => m.ShortDisplayName.StartsWith(paramsModelName)
);
model ??= currentModels.FirstOrDefault(m => m.ShortDisplayName.StartsWith(paramsModelName));
}
if (model is not null)

20
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ControlNetModule.cs

@ -42,9 +42,8 @@ public class ControlNetModule : ModuleBase
{
Name = e.Nodes.GetUniqueName("ControlNet_LoadImage"),
Image =
card.SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached(
"Inference"
) ?? throw new ValidationException("No ImageSource")
card.SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached("Inference")
?? throw new ValidationException("No ImageSource")
}
);
@ -52,9 +51,7 @@ public class ControlNetModule : ModuleBase
new ComfyNodeBuilder.ControlNetLoader
{
Name = e.Nodes.GetUniqueName("ControlNetLoader"),
ControlNetName =
card.SelectedModel?.FileName
?? throw new ValidationException("No SelectedModel"),
ControlNetName = card.SelectedModel?.FileName ?? throw new ValidationException("No SelectedModel"),
}
);
@ -64,10 +61,8 @@ public class ControlNetModule : ModuleBase
Name = e.Nodes.GetUniqueName("ControlNetApply"),
Image = imageLoad.Output1,
ControlNet = controlNetLoader.Output,
Positive =
e.Temp.Conditioning?.Positive ?? throw new ArgumentException("No Conditioning"),
Negative =
e.Temp.Conditioning?.Negative ?? throw new ArgumentException("No Conditioning"),
Positive = e.Temp.Conditioning?.Positive ?? throw new ArgumentException("No Conditioning"),
Negative = e.Temp.Conditioning?.Negative ?? throw new ArgumentException("No Conditioning"),
Strength = card.Strength,
StartPercent = card.StartPercent,
EndPercent = card.EndPercent,
@ -93,10 +88,7 @@ public class ControlNetModule : ModuleBase
}
);
e.Temp.RefinerConditioning = (
controlNetRefinerApply.Output1,
controlNetRefinerApply.Output2
);
e.Temp.RefinerConditioning = (controlNetRefinerApply.Output1, controlNetRefinerApply.Output2);
}
}
}

86
StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs

@ -54,11 +54,9 @@ public partial class OutputsPageViewModel : PageViewModelBase
private readonly ILogger<OutputsPageViewModel> logger;
public override string Title => Resources.Label_OutputsPageTitle;
public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.Grid, IsFilled = true };
public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.Grid, IsFilled = true };
public SourceCache<LocalImageFile, string> OutputsCache { get; } =
new(file => file.AbsolutePath);
public SourceCache<LocalImageFile, string> OutputsCache { get; } = new(file => file.AbsolutePath);
public IObservableCollection<OutputImageViewModel> Outputs { get; set; } =
new ObservableCollectionExtended<OutputImageViewModel>();
@ -88,8 +86,7 @@ public partial class OutputsPageViewModel : PageViewModelBase
[ObservableProperty]
private bool isConsolidating;
public bool CanShowOutputTypes =>
SelectedCategory?.Name?.Equals("Shared Output Folder") ?? false;
public bool CanShowOutputTypes => SelectedCategory?.Name?.Equals("Shared Output Folder") ?? false;
public string NumImagesSelected =>
NumItemsSelected == 1
@ -163,10 +160,7 @@ public partial class OutputsPageViewModel : PageViewModelBase
GetOutputs(path);
}
partial void OnSelectedCategoryChanged(
PackageOutputCategory? oldValue,
PackageOutputCategory? newValue
)
partial void OnSelectedCategoryChanged(PackageOutputCategory? oldValue, PackageOutputCategory? newValue)
{
if (oldValue == newValue || newValue == null)
return;
@ -217,13 +211,11 @@ public partial class OutputsPageViewModel : PageViewModelBase
var vm = new ImageViewerViewModel { ImageSource = image, LocalImageFile = item.ImageFile };
using var onNext = Observable
.FromEventPattern<DirectionalNavigationEventArgs>(
vm,
nameof(ImageViewerViewModel.NavigationRequested)
)
.FromEventPattern<DirectionalNavigationEventArgs>(vm, nameof(ImageViewerViewModel.NavigationRequested))
.Subscribe(ctx =>
{
Dispatcher.UIThread
Dispatcher
.UIThread
.InvokeAsync(async () =>
{
var sender = (ImageViewerViewModel)ctx.Sender!;
@ -232,9 +224,7 @@ public partial class OutputsPageViewModel : PageViewModelBase
if (newIndex >= 0 && newIndex < Outputs.Count)
{
var newImage = Outputs[newIndex];
var newImageSource = new ImageSource(
new FilePath(newImage.ImageFile.AbsolutePath)
);
var newImageSource = new ImageSource(new FilePath(newImage.ImageFile.AbsolutePath));
// Preload
await newImageSource.GetBitmapAsync();
@ -386,7 +376,9 @@ public partial class OutputsPageViewModel : PageViewModelBase
public async Task ConsolidateImages()
{
var stackPanel = new StackPanel();
stackPanel.Children.Add(
stackPanel
.Children
.Add(
new TextBlock
{
Text = Resources.Label_ConsolidateExplanation,
@ -401,7 +393,9 @@ public partial class OutputsPageViewModel : PageViewModelBase
continue;
}
stackPanel.Children.Add(
stackPanel
.Children
.Add(
new CheckBox
{
Content = $"{category.Name} ({category.Path})",
@ -430,25 +424,14 @@ public partial class OutputsPageViewModel : PageViewModelBase
Directory.CreateDirectory(settingsManager.ConsolidatedImagesDirectory);
foreach (
var category in stackPanel.Children.OfType<CheckBox>().Where(c => c.IsChecked == true)
)
foreach (var category in stackPanel.Children.OfType<CheckBox>().Where(c => c.IsChecked == true))
{
if (
string.IsNullOrWhiteSpace(category.Tag?.ToString())
|| !Directory.Exists(category.Tag?.ToString())
)
if (string.IsNullOrWhiteSpace(category.Tag?.ToString()) || !Directory.Exists(category.Tag?.ToString()))
continue;
var directory = category.Tag.ToString();
foreach (
var path in Directory.EnumerateFiles(
directory,
"*.png",
SearchOption.AllDirectories
)
)
foreach (var path in Directory.EnumerateFiles(directory, "*.png", SearchOption.AllDirectories))
{
try
{
@ -499,10 +482,7 @@ public partial class OutputsPageViewModel : PageViewModelBase
if (
!Directory.Exists(directory)
&& (
SelectedCategory.Path != settingsManager.ImagesDirectory
|| SelectedOutputType != SharedOutputType.All
)
&& (SelectedCategory.Path != settingsManager.ImagesDirectory || SelectedOutputType != SharedOutputType.All)
)
{
Directory.CreateDirectory(directory);
@ -534,23 +514,18 @@ public partial class OutputsPageViewModel : PageViewModelBase
var previouslySelectedCategory = SelectedCategory;
var packageCategories = settingsManager.Settings.InstalledPackages
var packageCategories = settingsManager
.Settings
.InstalledPackages
.Where(x => !x.UseSharedOutputFolder)
.Select(packageFactory.GetPackagePair)
.WhereNotNull()
.Where(
p =>
p.BasePackage.SharedOutputFolders != null
&& p.BasePackage.SharedOutputFolders.Any()
)
.Where(p => p.BasePackage.SharedOutputFolders != null && p.BasePackage.SharedOutputFolders.Any())
.Select(
pair =>
new PackageOutputCategory
{
Path = Path.Combine(
pair.InstalledPackage.FullPath!,
pair.BasePackage.OutputFolderName
),
Path = Path.Combine(pair.InstalledPackage.FullPath!, pair.BasePackage.OutputFolderName),
Name = pair.InstalledPackage.DisplayName ?? ""
}
)
@ -558,25 +533,16 @@ public partial class OutputsPageViewModel : PageViewModelBase
packageCategories.Insert(
0,
new PackageOutputCategory
{
Path = settingsManager.ImagesDirectory,
Name = "Shared Output Folder"
}
new PackageOutputCategory { Path = settingsManager.ImagesDirectory, Name = "Shared Output Folder" }
);
packageCategories.Insert(
1,
new PackageOutputCategory
{
Path = settingsManager.ImagesInferenceDirectory,
Name = "Inference"
}
new PackageOutputCategory { Path = settingsManager.ImagesInferenceDirectory, Name = "Inference" }
);
Categories = new ObservableCollection<PackageOutputCategory>(packageCategories);
SelectedCategory =
Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name)
?? Categories.First();
Categories.FirstOrDefault(x => x.Name == previouslySelectedCategory?.Name) ?? Categories.First();
}
}

5
StabilityMatrix.Avalonia/Views/Inference/InferenceImageToVideoView.axaml

@ -10,6 +10,7 @@
xmlns:vmInference="using:StabilityMatrix.Avalonia.ViewModels.Inference"
xmlns:dock="clr-namespace:StabilityMatrix.Avalonia.Controls.Dock"
xmlns:modelsInference="clr-namespace:StabilityMatrix.Avalonia.Models.Inference"
xmlns:gif="clr-namespace:Avalonia.Gif;assembly=Avalonia.Gif"
d:DataContext="{x:Static mocks:DesignData.InferenceImageToVideoViewModel}"
d:DesignHeight="800"
d:DesignWidth="1000"
@ -93,9 +94,9 @@
<Grid
x:CompileBindings="False"
DataContext="{Binding ElementName=Dock, Path=DataContext}">
<controls:ImageGalleryCard
<gif:GifImage
Grid.Row="0"
DataContext="{Binding ImageGalleryCardViewModel}" />
SourceUri="{Binding OutputUri}" />
<StackPanel
DataContext="{Binding OutputProgress}"

4
StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs

@ -58,8 +58,6 @@ public partial class InferencePage : UserControlBase
private void AddTabMenu_ImgToVideo_OnClick(object? sender, RoutedEventArgs e)
{
(DataContext as InferenceViewModel)!.AddTabCommand.Execute(
InferenceProjectType.ImageToVideo
);
(DataContext as InferenceViewModel)!.AddTabCommand.Execute(InferenceProjectType.ImageToVideo);
}
}

14
StabilityMatrix.Core/Animation/GifConverter.cs

@ -0,0 +1,14 @@
using ImageMagick;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Core.Animation;
public class GifConverter
{
public static async Task ConvertWebpToGif(FilePath filePath)
{
using var webp = new MagickImageCollection(filePath, MagickFormat.WebP);
var path = filePath.ToString().Replace(".webp", ".gif");
await webp.WriteAsync(path, MagickFormat.Gif).ConfigureAwait(false);
}
}

1
StabilityMatrix.Core/StabilityMatrix.Core.csproj

@ -31,6 +31,7 @@
<PackageReference Include="JetBrains.Annotations" Version="2023.3.0" />
<PackageReference Include="LiteDB" Version="5.0.17" />
<PackageReference Include="LiteDB.Async" Version="0.1.7" />
<PackageReference Include="Magick.NET-Q8-x64" Version="13.5.0" />
<PackageReference Include="MetadataExtractor" Version="2.8.1" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="8.0.0" />

6
StabilityMatrix.sln

@ -15,6 +15,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StabilityMatrix.Avalonia.Di
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StabilityMatrix.UITests", "StabilityMatrix.UITests\StabilityMatrix.UITests.csproj", "{8C7EDDD1-7FC1-4A15-B379-910A8DA7BCA6}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Avalonia.Gif", "Avalonia.Gif\Avalonia.Gif.csproj", "{72A73F1E-024B-4A25-AD34-626198D9527F}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -43,6 +45,10 @@ Global
{8C7EDDD1-7FC1-4A15-B379-910A8DA7BCA6}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8C7EDDD1-7FC1-4A15-B379-910A8DA7BCA6}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8C7EDDD1-7FC1-4A15-B379-910A8DA7BCA6}.Release|Any CPU.Build.0 = Release|Any CPU
{72A73F1E-024B-4A25-AD34-626198D9527F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{72A73F1E-024B-4A25-AD34-626198D9527F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{72A73F1E-024B-4A25-AD34-626198D9527F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{72A73F1E-024B-4A25-AD34-626198D9527F}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE

Loading…
Cancel
Save