using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using NLog;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Video;
using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// Inference view model for image-to-video generation.
/// <para>
/// The prompt is assembled from a stack of cards:
/// <list type="bullet">
/// <item>a selected input image loaded via a <c>LoadImage</c> node;</item>
/// <item>an img2vid model card;</item>
/// <item>an SVD img2vid conditioning card, whose width/height are copied from the sampler card;</item>
/// <item>the sampler card;</item>
/// <item>a video output settings card.</item>
/// </list>
/// </para>
/// </summary>
[View(typeof(InferenceImageToVideoView), persistent: true)]
[ManagedService]
[Transient]
public partial class InferenceImageToVideoViewModel
    : InferenceGenerationViewModelBase,
        IParametersLoadableState
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;
    private readonly IModelIndexService modelIndexService;

    /// <summary>Container that holds and orders the visible cards; not serialized.</summary>
    [JsonIgnore]
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Img2vid model selection card (serialized as "Model").</summary>
    [JsonPropertyName("Model")]
    public ImgToVidModelCardViewModel ModelCardViewModel { get; }

    /// <summary>Sampler card; also the source of the output width/height (serialized as "Sampler").</summary>
    [JsonPropertyName("Sampler")]
    public SamplerCardViewModel SamplerCardViewModel { get; }

    /// <summary>Number of batches to queue per generation (serialized as "BatchSize").</summary>
    [JsonPropertyName("BatchSize")]
    public BatchSizeCardViewModel BatchSizeCardViewModel { get; }

    /// <summary>Base seed card (serialized as "Seed").</summary>
    [JsonPropertyName("Seed")]
    public SeedCardViewModel SeedCardViewModel { get; }

    /// <summary>Input image card (serialized as "ImageLoader").</summary>
    [JsonPropertyName("ImageLoader")]
    public SelectImageCardViewModel SelectImageCardViewModel { get; }

    /// <summary>SVD img2vid conditioning card (serialized as "Conditioning").</summary>
    [JsonPropertyName("Conditioning")]
    public SvdImgToVidConditioningViewModel SvdImgToVidConditioningViewModel { get; }

    /// <summary>Video output settings card (serialized as "VideoOutput").</summary>
    [JsonPropertyName("VideoOutput")]
    public VideoOutputSettingsCardViewModel VideoOutputSettingsCardViewModel { get; }

    public InferenceImageToVideoViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ISettingsManager settingsManager,
        ServiceManager<ViewModelBase> vmFactory,
        IModelIndexService modelIndexService
    )
        : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
    {
        this.notificationService = notificationService;
        this.modelIndexService = modelIndexService;

        // Get sub view models from service manager

        SeedCardViewModel = vmFactory.Get<SeedCardViewModel>();
        SeedCardViewModel.GenerateNewSeed();

        ModelCardViewModel = vmFactory.Get<ImgToVidModelCardViewModel>();

        // Img2vid-specific sampler defaults: low CFG, Euler + Karras, full denoise
        SamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard =>
        {
            samplerCard.IsDimensionsEnabled = true;
            samplerCard.IsCfgScaleEnabled = true;
            samplerCard.IsSamplerSelectionEnabled = true;
            samplerCard.IsSchedulerSelectionEnabled = true;
            samplerCard.CfgScale = 2.5d;
            samplerCard.SelectedSampler = ComfySampler.Euler;
            samplerCard.SelectedScheduler = ComfyScheduler.Karras;
            samplerCard.IsDenoiseStrengthEnabled = true;
            samplerCard.DenoiseStrength = 1.0f;
        });

        BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>();
        SelectImageCardViewModel = vmFactory.Get<SelectImageCardViewModel>();
        SvdImgToVidConditioningViewModel = vmFactory.Get<SvdImgToVidConditioningViewModel>();
        VideoOutputSettingsCardViewModel = vmFactory.Get<VideoOutputSettingsCardViewModel>();

        StackCardViewModel = vmFactory.Get<StackCardViewModel>();

        // SelectImageCardViewModel is deliberately not part of the stack
        StackCardViewModel.AddCards(
            ModelCardViewModel,
            SvdImgToVidConditioningViewModel,
            SamplerCardViewModel,
            SeedCardViewModel,
            VideoOutputSettingsCardViewModel,
            BatchSizeCardViewModel
        );
    }

    /// <inheritdoc />
    protected override void BuildPrompt(BuildPromptEventArgs args)
    {
        base.BuildPrompt(args);

        var builder = args.Builder;

        // A per-batch seed override takes precedence over the seed card's value
        builder.Connections.Seed = args.SeedOverride switch
        {
            { } seed => Convert.ToUInt64(seed),
            _ => Convert.ToUInt64(SeedCardViewModel.Seed)
        };

        // Load models
        ModelCardViewModel.ApplyStep(args);

        // Setup latent from image; generation cannot proceed without an input image
        var imageLoad = builder.Nodes.AddTypedNode(
            new ComfyNodeBuilder.LoadImage
            {
                Name = builder.Nodes.GetUniqueName("ControlNet_LoadImage"),
                Image =
                    SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached("Inference")
                    ?? throw new ValidationException()
            }
        );
        builder.Connections.Primary = imageLoad.Output1;
        builder.Connections.PrimarySize = SelectImageCardViewModel.CurrentBitmapSize;

        // Setup img2vid stuff
        // Set width & height from SamplerCard
        SvdImgToVidConditioningViewModel.Width = SamplerCardViewModel.Width;
        SvdImgToVidConditioningViewModel.Height = SamplerCardViewModel.Height;
        SvdImgToVidConditioningViewModel.ApplyStep(args);

        // Setup Sampler and Refiner if enabled
        SamplerCardViewModel.ApplyStep(args);

        // Animated webp output
        VideoOutputSettingsCardViewModel.ApplyStep(args);
    }

    /// <inheritdoc />
    protected override IEnumerable<ImageSource> GetInputImages()
    {
        // Only the selected input image (if any) is an input for this workflow
        if (SelectImageCardViewModel.ImageSource is { } image)
        {
            yield return image;
        }
    }

    /// <inheritdoc />
    protected override async Task GenerateImageImpl(
        GenerateOverrides overrides,
        CancellationToken cancellationToken
    )
    {
        if (!await CheckClientConnectedWithPrompt() || !ClientManager.IsConnected)
        {
            return;
        }

        // If enabled, randomize the seed
        var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>();
        if (overrides is not { UseCurrentSeed: true } && seedCard.IsRandomizeEnabled)
        {
            seedCard.GenerateNewSeed();
        }

        var batches = BatchSizeCardViewModel.BatchCount;

        var batchArgs = new List<ImageGenerationEventArgs>();

        // Build every batch's prompt up front; batch i uses base seed + i
        for (var i = 0; i < batches; i++)
        {
            var seed = seedCard.Seed + i;

            var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed };
            BuildPrompt(buildPromptArgs);

            // SaveStateToParameters records only the seed card's base seed. Overwrite it
            // with the seed this batch was actually built with, so the saved parameters
            // match the prompt.
            var parameters = SaveStateToParameters(new GenerationParameters());
            parameters.Seed = Convert.ToUInt64(seed);

            var generationArgs = new ImageGenerationEventArgs
            {
                Client = ClientManager.Client,
                Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
                OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
                Parameters = parameters,
                Project = InferenceProjectDocument.FromLoadable(this),
                // Only clear output images on the first batch
                ClearOutputImages = i == 0
            };

            batchArgs.Add(generationArgs);
        }

        // Run batches
        foreach (var args in batchArgs)
        {
            await RunGeneration(args, cancellationToken);
        }
    }

    /// <inheritdoc />
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        SamplerCardViewModel.LoadStateFromParameters(parameters);
        ModelCardViewModel.LoadStateFromParameters(parameters);
        SvdImgToVidConditioningViewModel.LoadStateFromParameters(parameters);
        VideoOutputSettingsCardViewModel.LoadStateFromParameters(parameters);

        SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);
    }

    /// <inheritdoc />
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        // Each card may return a new instance, so thread the result through
        parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
        parameters = ModelCardViewModel.SaveStateToParameters(parameters);
        parameters = SvdImgToVidConditioningViewModel.SaveStateToParameters(parameters);
        parameters = VideoOutputSettingsCardViewModel.SaveStateToParameters(parameters);

        parameters.Seed = (ulong)SeedCardViewModel.Seed;

        return parameters;
    }

    /// <summary>
    /// Versioned state loading (migration for v2 deserialization).
    /// Only state saved with <paramref name="version"/> greater than 2 is loaded.
    /// State from version 2 or earlier is ignored, leaving this view model's current state unchanged.
    /// </summary>
    /// <param name="state">Serialized view model state.</param>
    /// <param name="version">Format version the state was saved with.</param>
    public override void LoadStateFromJsonObject(JsonObject state, int version)
    {
        if (version > 2)
        {
            LoadStateFromJsonObject(state);
        }
    }
}