|
|
|
using System;
|
|
|
|
using System.Collections.Generic;
|
|
|
|
using System.ComponentModel.DataAnnotations;
|
|
|
|
using System.Drawing;
|
|
|
|
using System.Linq;
|
|
|
|
using System.Text.Json.Nodes;
|
|
|
|
using System.Text.Json.Serialization;
|
|
|
|
using System.Threading;
|
|
|
|
using System.Threading.Tasks;
|
|
|
|
using CommunityToolkit.Mvvm.ComponentModel;
|
|
|
|
using NLog;
|
|
|
|
using StabilityMatrix.Avalonia.Models;
|
|
|
|
using StabilityMatrix.Avalonia.Models.Inference;
|
|
|
|
using StabilityMatrix.Avalonia.Services;
|
|
|
|
using StabilityMatrix.Avalonia.ViewModels.Base;
|
|
|
|
using StabilityMatrix.Avalonia.ViewModels.Inference.Video;
|
|
|
|
using StabilityMatrix.Avalonia.Views.Inference;
|
|
|
|
using StabilityMatrix.Core.Attributes;
|
|
|
|
using StabilityMatrix.Core.Helper;
|
|
|
|
using StabilityMatrix.Core.Models;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
|
|
|
|
using StabilityMatrix.Core.Models.FileInterfaces;
|
|
|
|
using StabilityMatrix.Core.Services;
|
|
|
|
|
|
|
|
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
|
|
|
|
|
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
|
|
|
|
|
|
|
|
[View(typeof(InferenceImageToVideoView), persistent: true)]
|
|
|
|
[ManagedService]
|
|
|
|
[Transient]
|
|
|
|
public partial class InferenceImageToVideoViewModel : InferenceGenerationViewModelBase, IParametersLoadableState
|
|
|
|
{
|
|
|
|
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
|
|
|
|
|
|
|
|
private readonly INotificationService notificationService;
|
|
|
|
private readonly IModelIndexService modelIndexService;
|
|
|
|
|
|
|
|
[JsonIgnore]
|
|
|
|
public StackCardViewModel StackCardViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("Model")]
|
|
|
|
public ImgToVidModelCardViewModel ModelCardViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("Sampler")]
|
|
|
|
public SamplerCardViewModel SamplerCardViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("BatchSize")]
|
|
|
|
public BatchSizeCardViewModel BatchSizeCardViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("Seed")]
|
|
|
|
public SeedCardViewModel SeedCardViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("ImageLoader")]
|
|
|
|
public SelectImageCardViewModel SelectImageCardViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("Conditioning")]
|
|
|
|
public SvdImgToVidConditioningViewModel SvdImgToVidConditioningViewModel { get; }
|
|
|
|
|
|
|
|
[JsonPropertyName("VideoOutput")]
|
|
|
|
public VideoOutputSettingsCardViewModel VideoOutputSettingsCardViewModel { get; }
|
|
|
|
|
|
|
|
[ObservableProperty]
|
|
|
|
[JsonIgnore]
|
|
|
|
private string outputUri;
|
|
|
|
|
|
|
|
public InferenceImageToVideoViewModel(
|
|
|
|
INotificationService notificationService,
|
|
|
|
IInferenceClientManager inferenceClientManager,
|
|
|
|
ISettingsManager settingsManager,
|
|
|
|
ServiceManager<ViewModelBase> vmFactory,
|
|
|
|
IModelIndexService modelIndexService
|
|
|
|
)
|
|
|
|
: base(vmFactory, inferenceClientManager, notificationService, settingsManager)
|
|
|
|
{
|
|
|
|
this.notificationService = notificationService;
|
|
|
|
this.modelIndexService = modelIndexService;
|
|
|
|
|
|
|
|
// Get sub view models from service manager
|
|
|
|
|
|
|
|
SeedCardViewModel = vmFactory.Get<SeedCardViewModel>();
|
|
|
|
SeedCardViewModel.GenerateNewSeed();
|
|
|
|
|
|
|
|
ModelCardViewModel = vmFactory.Get<ImgToVidModelCardViewModel>();
|
|
|
|
|
|
|
|
SamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard =>
|
|
|
|
{
|
|
|
|
samplerCard.IsDimensionsEnabled = true;
|
|
|
|
samplerCard.IsCfgScaleEnabled = true;
|
|
|
|
samplerCard.IsSamplerSelectionEnabled = true;
|
|
|
|
samplerCard.IsSchedulerSelectionEnabled = true;
|
|
|
|
samplerCard.CfgScale = 2.5d;
|
|
|
|
samplerCard.SelectedSampler = ComfySampler.Euler;
|
|
|
|
samplerCard.SelectedScheduler = ComfyScheduler.Karras;
|
|
|
|
samplerCard.IsDenoiseStrengthEnabled = true;
|
|
|
|
});
|
|
|
|
|
|
|
|
BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>();
|
|
|
|
|
|
|
|
SelectImageCardViewModel = vmFactory.Get<SelectImageCardViewModel>();
|
|
|
|
SvdImgToVidConditioningViewModel = vmFactory.Get<SvdImgToVidConditioningViewModel>();
|
|
|
|
VideoOutputSettingsCardViewModel = vmFactory.Get<VideoOutputSettingsCardViewModel>();
|
|
|
|
|
|
|
|
StackCardViewModel = vmFactory.Get<StackCardViewModel>();
|
|
|
|
StackCardViewModel.AddCards(
|
|
|
|
ModelCardViewModel,
|
|
|
|
SvdImgToVidConditioningViewModel,
|
|
|
|
SamplerCardViewModel,
|
|
|
|
SeedCardViewModel,
|
|
|
|
VideoOutputSettingsCardViewModel,
|
|
|
|
BatchSizeCardViewModel
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
public override void OnLoaded()
|
|
|
|
{
|
|
|
|
EventManager.Instance.ImageFileAdded += OnImageFileAdded;
|
|
|
|
}
|
|
|
|
|
|
|
|
private void OnImageFileAdded(object? sender, FilePath e)
|
|
|
|
{
|
|
|
|
if (!e.Extension.Contains("gif"))
|
|
|
|
return;
|
|
|
|
|
|
|
|
OutputUri = e;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
protected override void BuildPrompt(BuildPromptEventArgs args)
|
|
|
|
{
|
|
|
|
base.BuildPrompt(args);
|
|
|
|
|
|
|
|
var builder = args.Builder;
|
|
|
|
|
|
|
|
builder.Connections.Seed = args.SeedOverride switch
|
|
|
|
{
|
|
|
|
{ } seed => Convert.ToUInt64(seed),
|
|
|
|
_ => Convert.ToUInt64(SeedCardViewModel.Seed)
|
|
|
|
};
|
|
|
|
|
|
|
|
// Load models
|
|
|
|
ModelCardViewModel.ApplyStep(args);
|
|
|
|
|
|
|
|
// Setup latent from image
|
|
|
|
var imageLoad = builder
|
|
|
|
.Nodes
|
|
|
|
.AddTypedNode(
|
|
|
|
new ComfyNodeBuilder.LoadImage
|
|
|
|
{
|
|
|
|
Name = builder.Nodes.GetUniqueName("ControlNet_LoadImage"),
|
|
|
|
Image =
|
|
|
|
SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached("Inference")
|
|
|
|
?? throw new ValidationException()
|
|
|
|
}
|
|
|
|
);
|
|
|
|
builder.Connections.Primary = imageLoad.Output1;
|
|
|
|
builder.Connections.PrimarySize = SelectImageCardViewModel.CurrentBitmapSize;
|
|
|
|
|
|
|
|
// Setup img2vid stuff
|
|
|
|
// Set width & height from SamplerCard
|
|
|
|
SvdImgToVidConditioningViewModel.Width = SamplerCardViewModel.Width;
|
|
|
|
SvdImgToVidConditioningViewModel.Height = SamplerCardViewModel.Height;
|
|
|
|
SvdImgToVidConditioningViewModel.ApplyStep(args);
|
|
|
|
|
|
|
|
// Setup Sampler and Refiner if enabled
|
|
|
|
SamplerCardViewModel.ApplyStep(args);
|
|
|
|
|
|
|
|
// Animated webp output
|
|
|
|
VideoOutputSettingsCardViewModel.ApplyStep(args);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
protected override IEnumerable<ImageSource> GetInputImages()
|
|
|
|
{
|
|
|
|
if (SelectImageCardViewModel.ImageSource is { } image)
|
|
|
|
{
|
|
|
|
yield return image;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
protected override async Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken)
|
|
|
|
{
|
|
|
|
if (!await CheckClientConnectedWithPrompt() || !ClientManager.IsConnected)
|
|
|
|
{
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// If enabled, randomize the seed
|
|
|
|
var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();
|
|
|
|
if (overrides is not { UseCurrentSeed: true } && seedCard.IsRandomizeEnabled)
|
|
|
|
{
|
|
|
|
seedCard.GenerateNewSeed();
|
|
|
|
}
|
|
|
|
|
|
|
|
var batches = BatchSizeCardViewModel.BatchCount;
|
|
|
|
|
|
|
|
var batchArgs = new List<ImageGenerationEventArgs>();
|
|
|
|
|
|
|
|
for (var i = 0; i < batches; i++)
|
|
|
|
{
|
|
|
|
var seed = seedCard.Seed + i;
|
|
|
|
|
|
|
|
var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed };
|
|
|
|
BuildPrompt(buildPromptArgs);
|
|
|
|
|
|
|
|
var generationArgs = new ImageGenerationEventArgs
|
|
|
|
{
|
|
|
|
Client = ClientManager.Client,
|
|
|
|
Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
|
|
|
|
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
|
|
|
|
Parameters = SaveStateToParameters(new GenerationParameters()),
|
|
|
|
Project = InferenceProjectDocument.FromLoadable(this),
|
|
|
|
// Only clear output images on the first batch
|
|
|
|
ClearOutputImages = i == 0
|
|
|
|
};
|
|
|
|
|
|
|
|
batchArgs.Add(generationArgs);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Run batches
|
|
|
|
foreach (var args in batchArgs)
|
|
|
|
{
|
|
|
|
await RunGeneration(args, cancellationToken);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public void LoadStateFromParameters(GenerationParameters parameters)
|
|
|
|
{
|
|
|
|
SamplerCardViewModel.LoadStateFromParameters(parameters);
|
|
|
|
ModelCardViewModel.LoadStateFromParameters(parameters);
|
|
|
|
|
|
|
|
SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
|
|
|
|
{
|
|
|
|
parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
|
|
|
|
parameters = ModelCardViewModel.SaveStateToParameters(parameters);
|
|
|
|
|
|
|
|
parameters.Seed = (ulong)SeedCardViewModel.Seed;
|
|
|
|
|
|
|
|
return parameters;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Migration for v2 deserialization
|
|
|
|
public override void LoadStateFromJsonObject(JsonObject state, int version)
|
|
|
|
{
|
|
|
|
if (version > 2)
|
|
|
|
{
|
|
|
|
LoadStateFromJsonObject(state);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|