Multi-Platform Package Manager for Stable Diffusion
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

251 lines
8.8 KiB

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using NLog;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Video;
using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
/// <summary>
/// View model for the Inference "image to video" tab.
/// Composes the model, conditioning, sampler, seed, video output and batch size cards into a
/// <see cref="StackCardViewModel"/>, and builds a Comfy node graph that starts from the image
/// chosen in <see cref="SelectImageCardViewModel"/>.
/// </summary>
[View(typeof(InferenceImageToVideoView), persistent: true)]
[ManagedService]
[Transient]
public partial class InferenceImageToVideoViewModel
    : InferenceGenerationViewModelBase,
        IParametersLoadableState
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;
    private readonly IModelIndexService modelIndexService;

    /// <summary>
    /// Container holding the sub-cards below (in the order added by the constructor).
    /// Not serialized itself; the individual cards are serialized under their own property names.
    /// </summary>
    [JsonIgnore]
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Image-to-video model selection card (serialized as "Model").</summary>
    [JsonPropertyName("Model")]
    public ImgToVidModelCardViewModel ModelCardViewModel { get; }

    /// <summary>Sampler settings card; also supplies the output width/height (serialized as "Sampler").</summary>
    [JsonPropertyName("Sampler")]
    public SamplerCardViewModel SamplerCardViewModel { get; }

    /// <summary>Batch settings card; <c>BatchCount</c> controls how many prompts are queued (serialized as "BatchSize").</summary>
    [JsonPropertyName("BatchSize")]
    public BatchSizeCardViewModel BatchSizeCardViewModel { get; }

    /// <summary>Seed card holding the base seed for a generation run (serialized as "Seed").</summary>
    [JsonPropertyName("Seed")]
    public SeedCardViewModel SeedCardViewModel { get; }

    /// <summary>Source image picker (serialized as "ImageLoader").</summary>
    [JsonPropertyName("ImageLoader")]
    public SelectImageCardViewModel SelectImageCardViewModel { get; }

    /// <summary>Img2vid conditioning card (serialized as "Conditioning").</summary>
    [JsonPropertyName("Conditioning")]
    public SvdImgToVidConditioningViewModel SvdImgToVidConditioningViewModel { get; }

    /// <summary>Video output settings card (serialized as "VideoOutput").</summary>
    [JsonPropertyName("VideoOutput")]
    public VideoOutputSettingsCardViewModel VideoOutputSettingsCardViewModel { get; }

    /// <summary>
    /// Creates the view model and resolves all sub-card view models from <paramref name="vmFactory"/>.
    /// </summary>
    public InferenceImageToVideoViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ISettingsManager settingsManager,
        ServiceManager<ViewModelBase> vmFactory,
        IModelIndexService modelIndexService,
        RunningPackageService runningPackageService
    )
        : base(vmFactory, inferenceClientManager, notificationService, settingsManager, runningPackageService)
    {
        this.notificationService = notificationService;
        this.modelIndexService = modelIndexService;

        // Get sub view models from service manager

        SeedCardViewModel = vmFactory.Get<SeedCardViewModel>();
        SeedCardViewModel.GenerateNewSeed();

        ModelCardViewModel = vmFactory.Get<ImgToVidModelCardViewModel>();

        // Defaults for the img2vid sampler: CFG 2.5, Euler / Karras, full denoise.
        SamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard =>
        {
            samplerCard.IsDimensionsEnabled = true;
            samplerCard.IsCfgScaleEnabled = true;
            samplerCard.IsSamplerSelectionEnabled = true;
            samplerCard.IsSchedulerSelectionEnabled = true;
            samplerCard.CfgScale = 2.5d;
            samplerCard.SelectedSampler = ComfySampler.Euler;
            samplerCard.SelectedScheduler = ComfyScheduler.Karras;
            samplerCard.IsDenoiseStrengthEnabled = true;
            samplerCard.DenoiseStrength = 1.0f;
        });

        BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>();

        SelectImageCardViewModel = vmFactory.Get<SelectImageCardViewModel>();
        SvdImgToVidConditioningViewModel = vmFactory.Get<SvdImgToVidConditioningViewModel>();
        VideoOutputSettingsCardViewModel = vmFactory.Get<VideoOutputSettingsCardViewModel>();

        // Card order here is the order they are added to the stack.
        StackCardViewModel = vmFactory.Get<StackCardViewModel>();
        StackCardViewModel.AddCards(
            ModelCardViewModel,
            SvdImgToVidConditioningViewModel,
            SamplerCardViewModel,
            SeedCardViewModel,
            VideoOutputSettingsCardViewModel,
            BatchSizeCardViewModel
        );
    }

    /// <inheritdoc />
    /// <remarks>
    /// Uses <c>args.SeedOverride</c> when set, otherwise the seed card's seed.
    /// Throws <see cref="ValidationException"/> when no input image is selected.
    /// Side effect: copies the sampler card's width/height onto the conditioning card.
    /// </remarks>
    protected override void BuildPrompt(BuildPromptEventArgs args)
    {
        base.BuildPrompt(args);

        var builder = args.Builder;

        // Convert.ToUInt64 throws OverflowException for negative seeds.
        builder.Connections.Seed = args.SeedOverride switch
        {
            { } seed => Convert.ToUInt64(seed),
            _ => Convert.ToUInt64(SeedCardViewModel.Seed)
        };

        // Load models
        ModelCardViewModel.ApplyStep(args);

        // Setup latent from image
        // NOTE(review): the node name prefix says "ControlNet" although this node loads the
        // img2vid source image; kept unchanged here.
        var imageLoad = builder.Nodes.AddTypedNode(
            new ComfyNodeBuilder.LoadImage
            {
                Name = builder.Nodes.GetUniqueName("ControlNet_LoadImage"),
                Image =
                    SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached("Inference")
                    ?? throw new ValidationException()
            }
        );
        builder.Connections.Primary = imageLoad.Output1;
        builder.Connections.PrimarySize = SelectImageCardViewModel.CurrentBitmapSize;

        // Setup img2vid stuff
        // Set width & height from SamplerCard
        SvdImgToVidConditioningViewModel.Width = SamplerCardViewModel.Width;
        SvdImgToVidConditioningViewModel.Height = SamplerCardViewModel.Height;
        SvdImgToVidConditioningViewModel.ApplyStep(args);

        // Setup Sampler and Refiner if enabled
        SamplerCardViewModel.ApplyStep(args);

        // Animated webp output
        VideoOutputSettingsCardViewModel.ApplyStep(args);
    }

    /// <inheritdoc />
    /// <remarks>Yields the selected source image, if any.</remarks>
    protected override IEnumerable<ImageSource> GetInputImages()
    {
        if (SelectImageCardViewModel.ImageSource is { } image)
        {
            yield return image;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Builds one prompt per batch (batch <c>i</c> uses seed <c>baseSeed + i</c>), then runs the
    /// batches sequentially. Returns early if the client is not connected or model validation fails.
    /// </remarks>
    protected override async Task GenerateImageImpl(
        GenerateOverrides overrides,
        CancellationToken cancellationToken
    )
    {
        if (!await CheckClientConnectedWithPrompt() || !ClientManager.IsConnected)
        {
            return;
        }

        if (!await ModelCardViewModel.ValidateModel())
        {
            return;
        }

        // If enabled, randomize the seed
        var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();
        if (overrides is not { UseCurrentSeed: true } && seedCard.IsRandomizeEnabled)
        {
            seedCard.GenerateNewSeed();
        }

        var batches = BatchSizeCardViewModel.BatchCount;

        var batchArgs = new List<ImageGenerationEventArgs>();

        for (var i = 0; i < batches; i++)
        {
            // Consecutive seeds per batch, starting from the card's (possibly new) seed.
            var seed = seedCard.Seed + i;

            var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed };
            BuildPrompt(buildPromptArgs);

            // SaveStateToParameters records the seed card's base seed; overwrite it with the seed
            // this batch was actually built with so each output's metadata is reproducible.
            var parameters = SaveStateToParameters(new GenerationParameters());
            parameters.Seed = Convert.ToUInt64(seed);

            var generationArgs = new ImageGenerationEventArgs
            {
                Client = ClientManager.Client,
                Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
                OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
                Parameters = parameters,
                Project = InferenceProjectDocument.FromLoadable(this),
                FilesToTransfer = buildPromptArgs.FilesToTransfer,
                BatchIndex = i,
                // Only clear output images on the first batch
                ClearOutputImages = i == 0
            };

            batchArgs.Add(generationArgs);
        }

        // Run batches
        foreach (var args in batchArgs)
        {
            await RunGeneration(args, cancellationToken);
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Applies <paramref name="parameters"/> to the sampler, model, conditioning and video output
    /// cards, then to the seed card. <c>Convert.ToInt64</c> throws for seeds above <c>long.MaxValue</c>.
    /// </remarks>
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        SamplerCardViewModel.LoadStateFromParameters(parameters);
        ModelCardViewModel.LoadStateFromParameters(parameters);
        SvdImgToVidConditioningViewModel.LoadStateFromParameters(parameters);
        VideoOutputSettingsCardViewModel.LoadStateFromParameters(parameters);

        SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Threads <paramref name="parameters"/> through each card's save method and sets
    /// <c>Seed</c> from the seed card.
    /// </remarks>
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
        parameters = ModelCardViewModel.SaveStateToParameters(parameters);
        parameters = SvdImgToVidConditioningViewModel.SaveStateToParameters(parameters);
        parameters = VideoOutputSettingsCardViewModel.SaveStateToParameters(parameters);

        parameters.Seed = (ulong)SeedCardViewModel.Seed;

        return parameters;
    }

    /// <summary>
    /// Migration for v2 deserialization: state saved with <paramref name="version"/> 2 or lower
    /// is ignored and the current state is left unchanged. Newer versions delegate to the
    /// single-argument <c>LoadStateFromJsonObject</c> overload.
    /// </summary>
    public override void LoadStateFromJsonObject(JsonObject state, int version)
    {
        if (version > 2)
        {
            LoadStateFromJsonObject(state);
        }
    }
}