Browse Source

Add generic module image providers

pull/333/head
Ionite 12 months ago
parent
commit
5b7c2c1f1e
No known key found for this signature in database
  1. 8
      StabilityMatrix.Avalonia/Models/Inference/IInputImageProvider.cs
  2. 113
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  3. 14
      StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ModuleBase.cs
  4. 8
      StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs

8
StabilityMatrix.Avalonia/Models/Inference/IInputImageProvider.cs

@ -0,0 +1,8 @@
using System.Collections.Generic;
namespace StabilityMatrix.Avalonia.Models.Inference;
public interface IInputImageProvider
{
IEnumerable<ImageSource> GetInputImages();
}

113
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -1,7 +1,5 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Drawing;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
@ -10,7 +8,6 @@ using System.Threading.Tasks;
using DynamicData.Binding;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
@ -18,11 +15,7 @@ using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
using StabilityMatrix.Core.Services;
using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView;
@ -30,7 +23,7 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), persistent: true)]
[View(typeof(InferenceTextToImageView), IsPersistent = true)]
[ManagedService]
[Transient]
public class InferenceTextToImageViewModel
@ -57,46 +50,12 @@ public class InferenceTextToImageViewModel
[JsonPropertyName("Prompt")]
public PromptCardViewModel PromptCardViewModel { get; }
[JsonPropertyName("Upscaler")]
public UpscalerCardViewModel UpscalerCardViewModel { get; }
[JsonPropertyName("HiresSampler")]
public SamplerCardViewModel HiresSamplerCardViewModel { get; }
[JsonPropertyName("HiresUpscaler")]
public UpscalerCardViewModel HiresUpscalerCardViewModel { get; }
[JsonPropertyName("FreeU")]
public FreeUCardViewModel FreeUCardViewModel { get; }
[JsonPropertyName("BatchSize")]
public BatchSizeCardViewModel BatchSizeCardViewModel { get; }
[JsonPropertyName("Seed")]
public SeedCardViewModel SeedCardViewModel { get; }
public bool IsFreeUEnabled => false;
public bool IsHiresFixEnabled => false;
public bool IsUpscaleEnabled => false;
/*public bool IsFreeUEnabled
{
get => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled;
set => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled = value;
}
public bool IsHiresFixEnabled
{
get => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled;
set => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled = value;
}
public bool IsUpscaleEnabled
{
get => StackCardViewModel.GetCard<StackExpanderViewModel>(2).IsEnabled;
set => StackCardViewModel.GetCard<StackExpanderViewModel>(2).IsEnabled = value;
}*/
public InferenceTextToImageViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
@ -125,13 +84,6 @@ public class InferenceTextToImageViewModel
});
PromptCardViewModel = vmFactory.Get<PromptCardViewModel>();
HiresSamplerCardViewModel = vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{
samplerCard.IsDenoiseStrengthEnabled = true;
});
HiresUpscalerCardViewModel = vmFactory.Get<UpscalerCardViewModel>();
UpscalerCardViewModel = vmFactory.Get<UpscalerCardViewModel>();
FreeUCardViewModel = vmFactory.Get<FreeUCardViewModel>();
BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>();
ModulesCardViewModel = vmFactory.Get<StackEditableCardViewModel>(modulesCard =>
@ -207,10 +159,15 @@ public class InferenceTextToImageViewModel
/// <inheritdoc />
protected override IEnumerable<ImageSource> GetInputImages()
{
// TODO support hires in some generic way
return SamplerCardViewModel.ModulesCardViewModel.Cards
.OfType<ControlNetModule>()
var samplerImages = SamplerCardViewModel.ModulesCardViewModel.Cards
.OfType<IInputImageProvider>()
.SelectMany(m => m.GetInputImages());
var moduleImages = ModulesCardViewModel.Cards
.OfType<IInputImageProvider>()
.SelectMany(m => m.GetInputImages());
return samplerImages.Concat(moduleImages);
}
/// <inheritdoc />
@ -295,21 +252,53 @@ public class InferenceTextToImageViewModel
return parameters;
}
// Migration for v2 deserialization
// Deserialization overrides
public override void LoadStateFromJsonObject(JsonObject state, int version)
{
if (version > 2)
// For v2 and below, do migration
if (version <= 2)
{
LoadStateFromJsonObject(state);
}
ModulesCardViewModel.Clear();
ModulesCardViewModel.Clear();
// Add by default the original cards as steps - HiresFix, Upscaler
ModulesCardViewModel.AddModule<HiresFixModule>(module =>
{
module.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsHiresFixEnabled");
if (state.TryGetPropertyValue("HiresSampler", out var hiresSamplerState))
{
module
.GetCard<SamplerCardViewModel>()
.LoadStateFromJsonObject(hiresSamplerState!.AsObject());
}
if (state.TryGetPropertyValue("HiresUpscaler", out var hiresUpscalerState))
{
module
.GetCard<UpscalerCardViewModel>()
.LoadStateFromJsonObject(hiresUpscalerState!.AsObject());
}
});
// Add by default the original cards - FreeU, HiresFix, Upscaler
var hiresFix = ModulesCardViewModel.AddModule<HiresFixModule>();
var upscaler = ModulesCardViewModel.AddModule<UpscalerModule>();
ModulesCardViewModel.AddModule<UpscalerModule>(module =>
{
module.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsUpscaleEnabled");
if (state.TryGetPropertyValue("Upscaler", out var upscalerState))
{
module
.GetCard<UpscalerCardViewModel>()
.LoadStateFromJsonObject(upscalerState!.AsObject());
}
});
// Add FreeU to sampler
SamplerCardViewModel.ModulesCardViewModel.AddModule<FreeUModule>(module =>
{
module.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsFreeUEnabled");
});
}
hiresFix.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsHiresFixEnabled");
upscaler.IsEnabled = state.GetPropertyValueOrDefault<bool>("IsUpscaleEnabled");
base.LoadStateFromJsonObject(state);
}
}

14
StabilityMatrix.Avalonia/ViewModels/Inference/Modules/ModuleBase.cs

@ -1,10 +1,12 @@
using StabilityMatrix.Avalonia.Models.Inference;
using System.Collections.Generic;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
public abstract class ModuleBase : StackExpanderViewModel, IComfyStep
public abstract class ModuleBase : StackExpanderViewModel, IComfyStep, IInputImageProvider
{
/// <inheritdoc />
protected ModuleBase(ServiceManager<ViewModelBase> vmFactory)
@ -27,4 +29,12 @@ public abstract class ModuleBase : StackExpanderViewModel, IComfyStep
}
protected abstract void OnApplyStep(ModuleApplyStepEventArgs e);
/// <inheritdoc />
IEnumerable<ImageSource> IInputImageProvider.GetInputImages() => GetInputImages();
protected virtual IEnumerable<ImageSource> GetInputImages()
{
yield break;
}
}

8
StabilityMatrix.Avalonia/ViewModels/Inference/StackEditableCardViewModel.cs

@ -88,6 +88,14 @@ public partial class StackEditableCardViewModel : StackViewModelBase
return card;
}
public T AddModule<T>(Action<T> initializer)
where T : ModuleBase
{
var card = vmFactory.Get(initializer);
AddCards(card);
return card;
}
[RelayCommand]
private void AddModule(Type type)
{

Loading…
Cancel
Save