You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
248 lines
8.7 KiB
248 lines
8.7 KiB
using System; |
|
using System.Collections.Generic; |
|
using System.ComponentModel.DataAnnotations; |
|
using System.Linq; |
|
using System.Text.Json.Nodes; |
|
using System.Text.Json.Serialization; |
|
using System.Threading; |
|
using System.Threading.Tasks; |
|
using CommunityToolkit.Mvvm.ComponentModel; |
|
using NLog; |
|
using StabilityMatrix.Avalonia.Models; |
|
using StabilityMatrix.Avalonia.Models.Inference; |
|
using StabilityMatrix.Avalonia.Services; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Avalonia.ViewModels.Inference.Video; |
|
using StabilityMatrix.Avalonia.Views.Inference; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Helper; |
|
using StabilityMatrix.Core.Models; |
|
using StabilityMatrix.Core.Models.Api.Comfy; |
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; |
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
using StabilityMatrix.Core.Services; |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Inference; |
|
|
|
[View(typeof(InferenceImageToVideoView), persistent: true)] |
|
[ManagedService] |
|
[Transient] |
|
public partial class InferenceImageToVideoViewModel |
|
: InferenceGenerationViewModelBase, |
|
IParametersLoadableState |
|
{ |
|
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); |
|
|
|
private readonly INotificationService notificationService; |
|
private readonly IModelIndexService modelIndexService; |
|
|
|
[JsonIgnore] |
|
public StackCardViewModel StackCardViewModel { get; } |
|
|
|
[JsonPropertyName("Model")] |
|
public ImgToVidModelCardViewModel ModelCardViewModel { get; } |
|
|
|
[JsonPropertyName("Sampler")] |
|
public SamplerCardViewModel SamplerCardViewModel { get; } |
|
|
|
[JsonPropertyName("BatchSize")] |
|
public BatchSizeCardViewModel BatchSizeCardViewModel { get; } |
|
|
|
[JsonPropertyName("Seed")] |
|
public SeedCardViewModel SeedCardViewModel { get; } |
|
|
|
[JsonPropertyName("ImageLoader")] |
|
public SelectImageCardViewModel SelectImageCardViewModel { get; } |
|
|
|
[JsonPropertyName("Conditioning")] |
|
public SvdImgToVidConditioningViewModel SvdImgToVidConditioningViewModel { get; } |
|
|
|
[JsonPropertyName("VideoOutput")] |
|
public VideoOutputSettingsCardViewModel VideoOutputSettingsCardViewModel { get; } |
|
|
|
public InferenceImageToVideoViewModel( |
|
INotificationService notificationService, |
|
IInferenceClientManager inferenceClientManager, |
|
ISettingsManager settingsManager, |
|
ServiceManager<ViewModelBase> vmFactory, |
|
IModelIndexService modelIndexService |
|
) |
|
: base(vmFactory, inferenceClientManager, notificationService, settingsManager) |
|
{ |
|
this.notificationService = notificationService; |
|
this.modelIndexService = modelIndexService; |
|
|
|
// Get sub view models from service manager |
|
|
|
SeedCardViewModel = vmFactory.Get<SeedCardViewModel>(); |
|
SeedCardViewModel.GenerateNewSeed(); |
|
|
|
ModelCardViewModel = vmFactory.Get<ImgToVidModelCardViewModel>(); |
|
|
|
SamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard => |
|
{ |
|
samplerCard.IsDimensionsEnabled = true; |
|
samplerCard.IsCfgScaleEnabled = true; |
|
samplerCard.IsSamplerSelectionEnabled = true; |
|
samplerCard.IsSchedulerSelectionEnabled = true; |
|
samplerCard.CfgScale = 2.5d; |
|
samplerCard.SelectedSampler = ComfySampler.Euler; |
|
samplerCard.SelectedScheduler = ComfyScheduler.Karras; |
|
samplerCard.IsDenoiseStrengthEnabled = true; |
|
samplerCard.DenoiseStrength = 1.0f; |
|
}); |
|
|
|
BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>(); |
|
|
|
SelectImageCardViewModel = vmFactory.Get<SelectImageCardViewModel>(); |
|
SvdImgToVidConditioningViewModel = vmFactory.Get<SvdImgToVidConditioningViewModel>(); |
|
VideoOutputSettingsCardViewModel = vmFactory.Get<VideoOutputSettingsCardViewModel>(); |
|
|
|
StackCardViewModel = vmFactory.Get<StackCardViewModel>(); |
|
StackCardViewModel.AddCards( |
|
ModelCardViewModel, |
|
SvdImgToVidConditioningViewModel, |
|
SamplerCardViewModel, |
|
SeedCardViewModel, |
|
VideoOutputSettingsCardViewModel, |
|
BatchSizeCardViewModel |
|
); |
|
} |
|
|
|
/// <inheritdoc /> |
|
protected override void BuildPrompt(BuildPromptEventArgs args) |
|
{ |
|
base.BuildPrompt(args); |
|
|
|
var builder = args.Builder; |
|
|
|
builder.Connections.Seed = args.SeedOverride switch |
|
{ |
|
{ } seed => Convert.ToUInt64(seed), |
|
_ => Convert.ToUInt64(SeedCardViewModel.Seed) |
|
}; |
|
|
|
// Load models |
|
ModelCardViewModel.ApplyStep(args); |
|
|
|
// Setup latent from image |
|
var imageLoad = builder.Nodes.AddTypedNode( |
|
new ComfyNodeBuilder.LoadImage |
|
{ |
|
Name = builder.Nodes.GetUniqueName("ControlNet_LoadImage"), |
|
Image = |
|
SelectImageCardViewModel.ImageSource?.GetHashGuidFileNameCached("Inference") |
|
?? throw new ValidationException() |
|
} |
|
); |
|
builder.Connections.Primary = imageLoad.Output1; |
|
builder.Connections.PrimarySize = SelectImageCardViewModel.CurrentBitmapSize; |
|
|
|
// Setup img2vid stuff |
|
// Set width & height from SamplerCard |
|
SvdImgToVidConditioningViewModel.Width = SamplerCardViewModel.Width; |
|
SvdImgToVidConditioningViewModel.Height = SamplerCardViewModel.Height; |
|
SvdImgToVidConditioningViewModel.ApplyStep(args); |
|
|
|
// Setup Sampler and Refiner if enabled |
|
SamplerCardViewModel.ApplyStep(args); |
|
|
|
// Animated webp output |
|
VideoOutputSettingsCardViewModel.ApplyStep(args); |
|
} |
|
|
|
/// <inheritdoc /> |
|
protected override IEnumerable<ImageSource> GetInputImages() |
|
{ |
|
if (SelectImageCardViewModel.ImageSource is { } image) |
|
{ |
|
yield return image; |
|
} |
|
} |
|
|
|
/// <inheritdoc /> |
|
protected override async Task GenerateImageImpl( |
|
GenerateOverrides overrides, |
|
CancellationToken cancellationToken |
|
) |
|
{ |
|
if (!await CheckClientConnectedWithPrompt() || !ClientManager.IsConnected) |
|
{ |
|
return; |
|
} |
|
|
|
if (!await ModelCardViewModel.ValidateModel()) |
|
return; |
|
|
|
// If enabled, randomize the seed |
|
var seedCard = StackCardViewModel.GetCard<SeedCardViewModel>(); |
|
if (overrides is not { UseCurrentSeed: true } && seedCard.IsRandomizeEnabled) |
|
{ |
|
seedCard.GenerateNewSeed(); |
|
} |
|
|
|
var batches = BatchSizeCardViewModel.BatchCount; |
|
|
|
var batchArgs = new List<ImageGenerationEventArgs>(); |
|
|
|
for (var i = 0; i < batches; i++) |
|
{ |
|
var seed = seedCard.Seed + i; |
|
|
|
var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed }; |
|
BuildPrompt(buildPromptArgs); |
|
|
|
var generationArgs = new ImageGenerationEventArgs |
|
{ |
|
Client = ClientManager.Client, |
|
Nodes = buildPromptArgs.Builder.ToNodeDictionary(), |
|
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(), |
|
Parameters = SaveStateToParameters(new GenerationParameters()), |
|
Project = InferenceProjectDocument.FromLoadable(this), |
|
// Only clear output images on the first batch |
|
ClearOutputImages = i == 0 |
|
}; |
|
|
|
batchArgs.Add(generationArgs); |
|
} |
|
|
|
// Run batches |
|
foreach (var args in batchArgs) |
|
{ |
|
await RunGeneration(args, cancellationToken); |
|
} |
|
} |
|
|
|
/// <inheritdoc /> |
|
public void LoadStateFromParameters(GenerationParameters parameters) |
|
{ |
|
SamplerCardViewModel.LoadStateFromParameters(parameters); |
|
ModelCardViewModel.LoadStateFromParameters(parameters); |
|
SvdImgToVidConditioningViewModel.LoadStateFromParameters(parameters); |
|
VideoOutputSettingsCardViewModel.LoadStateFromParameters(parameters); |
|
|
|
SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed); |
|
} |
|
|
|
/// <inheritdoc /> |
|
public GenerationParameters SaveStateToParameters(GenerationParameters parameters) |
|
{ |
|
parameters = SamplerCardViewModel.SaveStateToParameters(parameters); |
|
parameters = ModelCardViewModel.SaveStateToParameters(parameters); |
|
parameters = SvdImgToVidConditioningViewModel.SaveStateToParameters(parameters); |
|
parameters = VideoOutputSettingsCardViewModel.SaveStateToParameters(parameters); |
|
|
|
parameters.Seed = (ulong)SeedCardViewModel.Seed; |
|
|
|
return parameters; |
|
} |
|
|
|
// Migration for v2 deserialization |
|
public override void LoadStateFromJsonObject(JsonObject state, int version) |
|
{ |
|
if (version > 2) |
|
{ |
|
LoadStateFromJsonObject(state); |
|
} |
|
} |
|
}
|
|
|