diff --git a/StabilityMatrix.Avalonia/Models/Inference/IInputImageProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/IInputImageProvider.cs new file mode 100644 index 00000000..a343a7f7 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/IInputImageProvider.cs @@ -0,0 +1,8 @@ +using System.Collections.Generic; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public interface IInputImageProvider +{ + IEnumerable GetInputImages(); +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 072db3d4..a408fff5 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -1,7 +1,5 @@ using System; using System.Collections.Generic; -using System.ComponentModel.DataAnnotations; -using System.Drawing; using System.Linq; using System.Text.Json.Nodes; using System.Text.Json.Serialization; @@ -10,7 +8,6 @@ using System.Threading.Tasks; using DynamicData.Binding; using NLog; using StabilityMatrix.Avalonia.Extensions; -using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; @@ -18,11 +15,7 @@ using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; -using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; -using StabilityMatrix.Core.Models.Api.Comfy; -using StabilityMatrix.Core.Models.Api.Comfy.Nodes; -using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; using StabilityMatrix.Core.Services; using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView; @@ -30,7 +23,7 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere namespace StabilityMatrix.Avalonia.ViewModels.Inference; -[View(typeof(InferenceTextToImageView), persistent: true)] +[View(typeof(InferenceTextToImageView), IsPersistent = true)] [ManagedService] [Transient] public class InferenceTextToImageViewModel @@ -57,46 +50,12 @@ public class InferenceTextToImageViewModel [JsonPropertyName("Prompt")] public PromptCardViewModel PromptCardViewModel { get; } - [JsonPropertyName("Upscaler")] - public UpscalerCardViewModel UpscalerCardViewModel { get; } - - [JsonPropertyName("HiresSampler")] - public SamplerCardViewModel HiresSamplerCardViewModel { get; } - - [JsonPropertyName("HiresUpscaler")] - public UpscalerCardViewModel HiresUpscalerCardViewModel { get; } - - [JsonPropertyName("FreeU")] - public FreeUCardViewModel FreeUCardViewModel { get; } - [JsonPropertyName("BatchSize")] public BatchSizeCardViewModel BatchSizeCardViewModel { get; } [JsonPropertyName("Seed")] public SeedCardViewModel SeedCardViewModel { get; } - public bool IsFreeUEnabled => false; - public bool IsHiresFixEnabled => false; - public bool IsUpscaleEnabled => false; - - /*public bool IsFreeUEnabled - { - get => StackCardViewModel.GetCard().IsEnabled; - set => StackCardViewModel.GetCard().IsEnabled = value; - } - - public bool IsHiresFixEnabled - { - get => StackCardViewModel.GetCard(1).IsEnabled; - set => StackCardViewModel.GetCard(1).IsEnabled = value; - } - - public bool IsUpscaleEnabled - { - get => StackCardViewModel.GetCard(2).IsEnabled; - set => StackCardViewModel.GetCard(2).IsEnabled = value; - }*/ - public InferenceTextToImageViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, @@ -125,13 +84,6 @@ public class InferenceTextToImageViewModel }); PromptCardViewModel = vmFactory.Get(); - HiresSamplerCardViewModel = vmFactory.Get(samplerCard => - { - samplerCard.IsDenoiseStrengthEnabled = true; - }); - HiresUpscalerCardViewModel = vmFactory.Get(); - UpscalerCardViewModel = vmFactory.Get(); - FreeUCardViewModel = vmFactory.Get(); BatchSizeCardViewModel = vmFactory.Get(); ModulesCardViewModel = vmFactory.Get(modulesCard => @@ -207,10 +159,15 @@ public class InferenceTextToImageViewModel /// protected override IEnumerable GetInputImages() { - // TODO support hires in some generic way - return SamplerCardViewModel.ModulesCardViewModel.Cards - .OfType() + var samplerImages = SamplerCardViewModel.ModulesCardViewModel.Cards + .OfType() + .SelectMany(m => m.GetInputImages()); + + var moduleImages = ModulesCardViewModel.Cards + .OfType() .SelectMany(m => m.GetInputImages()); + + return samplerImages.Concat(moduleImages); } /// @@ -295,21 +252,53 @@ public class InferenceTextToImageViewModel return parameters; } - // Migration for v2 deserialization + // Deserialization overrides public override void LoadStateFromJsonObject(JsonObject state, int version) { - if (version > 2) + // For v2 and below, do migration + if (version <= 2) { - LoadStateFromJsonObject(state); - } + ModulesCardViewModel.Clear(); - ModulesCardViewModel.Clear(); + // Add by default the original cards as steps - HiresFix, Upscaler + ModulesCardViewModel.AddModule(module => + { + module.IsEnabled = state.GetPropertyValueOrDefault("IsHiresFixEnabled"); + + if (state.TryGetPropertyValue("HiresSampler", out var hiresSamplerState)) + { + module + .GetCard() + .LoadStateFromJsonObject(hiresSamplerState!.AsObject()); + } + + if (state.TryGetPropertyValue("HiresUpscaler", out var hiresUpscalerState)) + { + module + .GetCard() + .LoadStateFromJsonObject(hiresUpscalerState!.AsObject()); + } + }); - // Add by default the original cards - FreeU, HiresFix, Upscaler - var hiresFix = ModulesCardViewModel.AddModule(); - var upscaler = ModulesCardViewModel.AddModule(); + ModulesCardViewModel.AddModule(module => + { + module.IsEnabled = state.GetPropertyValueOrDefault("IsUpscaleEnabled"); + + if (state.TryGetPropertyValue("Upscaler", out var upscalerState)) + { + module + .GetCard() + .LoadStateFromJsonObject(upscalerState!.AsObject()); + } + }); + + // Add FreeU to sampler + SamplerCardViewModel.ModulesCardViewModel.AddModule(module => + { + module.IsEnabled = state.GetPropertyValueOrDefault("IsFreeUEnabled"); + }); + } - hiresFix.IsEnabled = state.GetPropertyValueOrDefault("IsHiresFixEnabled"); - upscaler.IsEnabled = state.GetPropertyValueOrDefault("IsUpscaleEnabled"); + base.LoadStateFromJsonObject(state); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ModuleBase.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ModuleBase.cs index 8172ec48..a8436b63 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ModuleBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ModuleBase.cs @@ -1,10 +1,12 @@ -using StabilityMatrix.Avalonia.Models.Inference; +using System.Collections.Generic; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; -public abstract class ModuleBase : StackExpanderViewModel, IComfyStep +public abstract class ModuleBase : StackExpanderViewModel, IComfyStep, IInputImageProvider { /// protected ModuleBase(ServiceManager vmFactory) @@ -27,4 +29,12 @@ public abstract class ModuleBase : StackExpanderViewModel, IComfyStep } protected abstract void OnApplyStep(ModuleApplyStepEventArgs e); + + /// + IEnumerable IInputImageProvider.GetInputImages() => GetInputImages(); + + protected virtual IEnumerable GetInputImages() + { + yield break; + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs index 79267e11..7f2a1d35 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs @@ -88,6 +88,14 @@ public partial class StackEditableCardViewModel : StackViewModelBase return card; } + public T AddModule(Action initializer) + where T : ModuleBase + { + var card = vmFactory.Get(initializer); + AddCards(card); + return card; + } + [RelayCommand] private void AddModule(Type type) {