|
|
|
using System;
|
|
|
|
using System.ComponentModel.DataAnnotations;
|
|
|
|
using System.Linq;
|
|
|
|
using System.Text.Json.Nodes;
|
|
|
|
using CommunityToolkit.Mvvm.ComponentModel;
|
|
|
|
using StabilityMatrix.Avalonia.Controls;
|
|
|
|
using StabilityMatrix.Avalonia.Models;
|
|
|
|
using StabilityMatrix.Avalonia.Models.Inference;
|
|
|
|
using StabilityMatrix.Avalonia.Services;
|
|
|
|
using StabilityMatrix.Avalonia.ViewModels.Base;
|
|
|
|
using StabilityMatrix.Core.Attributes;
|
|
|
|
using StabilityMatrix.Core.Models;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
|
|
|
|
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
|
|
|
|
|
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
|
|
|
|
|
|
|
|
[View(typeof(ModelCard))]
|
|
|
|
[ManagedService]
|
|
|
|
[Transient]
|
|
|
|
public partial class ModelCardViewModel
|
|
|
|
: LoadableViewModelBase,
|
|
|
|
IParametersLoadableState,
|
|
|
|
IComfyStep
|
|
|
|
{
|
|
|
|
[ObservableProperty]
|
|
|
|
private HybridModelFile? selectedModel;
|
|
|
|
|
|
|
|
[ObservableProperty]
|
|
|
|
private bool isRefinerSelectionEnabled;
|
|
|
|
|
|
|
|
[ObservableProperty]
|
|
|
|
private HybridModelFile? selectedRefiner = HybridModelFile.None;
|
|
|
|
|
|
|
|
[ObservableProperty]
|
|
|
|
private HybridModelFile? selectedVae = HybridModelFile.Default;
|
|
|
|
|
|
|
|
[ObservableProperty]
|
|
|
|
private bool isVaeSelectionEnabled;
|
|
|
|
|
|
|
|
public string? SelectedModelName => SelectedModel?.RelativePath;
|
|
|
|
|
|
|
|
public string? SelectedVaeName => SelectedVae?.RelativePath;
|
|
|
|
|
|
|
|
public IInferenceClientManager ClientManager { get; }
|
|
|
|
|
|
|
|
public ModelCardViewModel(IInferenceClientManager clientManager)
|
|
|
|
{
|
|
|
|
ClientManager = clientManager;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public void ApplyStep(ModuleApplyStepEventArgs e)
|
|
|
|
{
|
|
|
|
// Load base checkpoint
|
|
|
|
var baseLoader = e.Nodes.AddTypedNode(
|
|
|
|
new ComfyNodeBuilder.CheckpointLoaderSimple
|
|
|
|
{
|
|
|
|
Name = "CheckpointLoader",
|
|
|
|
CkptName =
|
|
|
|
SelectedModel?.RelativePath
|
|
|
|
?? throw new ValidationException("Model not selected")
|
|
|
|
}
|
|
|
|
);
|
|
|
|
|
|
|
|
e.Builder.Connections.BaseModel = baseLoader.Output1;
|
|
|
|
e.Builder.Connections.BaseClip = baseLoader.Output2;
|
|
|
|
e.Builder.Connections.BaseVAE = baseLoader.Output3;
|
|
|
|
|
|
|
|
// Load refiner
|
|
|
|
if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false })
|
|
|
|
{
|
|
|
|
var refinerLoader = e.Nodes.AddTypedNode(
|
|
|
|
new ComfyNodeBuilder.CheckpointLoaderSimple
|
|
|
|
{
|
|
|
|
Name = "Refiner_CheckpointLoader",
|
|
|
|
CkptName =
|
|
|
|
SelectedRefiner?.RelativePath
|
|
|
|
?? throw new ValidationException("Refiner Model enabled but not selected")
|
|
|
|
}
|
|
|
|
);
|
|
|
|
|
|
|
|
e.Builder.Connections.RefinerModel = refinerLoader.Output1;
|
|
|
|
e.Builder.Connections.RefinerClip = refinerLoader.Output2;
|
|
|
|
e.Builder.Connections.RefinerVAE = refinerLoader.Output3;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Load custom VAE
|
|
|
|
if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false })
|
|
|
|
{
|
|
|
|
var vaeLoader = e.Nodes.AddTypedNode(
|
|
|
|
new ComfyNodeBuilder.VAELoader
|
|
|
|
{
|
|
|
|
Name = "VAELoader",
|
|
|
|
VaeName =
|
|
|
|
SelectedVae?.RelativePath
|
|
|
|
?? throw new ValidationException("VAE enabled but not selected")
|
|
|
|
}
|
|
|
|
);
|
|
|
|
|
|
|
|
e.Builder.Connections.PrimaryVAE = vaeLoader.Output;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public override JsonObject SaveStateToJsonObject()
|
|
|
|
{
|
|
|
|
return SerializeModel(
|
|
|
|
new ModelCardModel
|
|
|
|
{
|
|
|
|
SelectedModelName = SelectedModelName,
|
|
|
|
SelectedVaeName = SelectedVaeName,
|
|
|
|
IsVaeSelectionEnabled = IsVaeSelectionEnabled
|
|
|
|
}
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public override void LoadStateFromJsonObject(JsonObject state)
|
|
|
|
{
|
|
|
|
var model = DeserializeModel<ModelCardModel>(state);
|
|
|
|
|
|
|
|
SelectedModel = model.SelectedModelName is null
|
|
|
|
? null
|
|
|
|
: ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName);
|
|
|
|
|
|
|
|
SelectedVae = model.SelectedVaeName is null
|
|
|
|
? HybridModelFile.Default
|
|
|
|
: ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public void LoadStateFromParameters(GenerationParameters parameters)
|
|
|
|
{
|
|
|
|
if (parameters.ModelName is not { } paramsModelName)
|
|
|
|
return;
|
|
|
|
|
|
|
|
var currentModels = ClientManager.Models;
|
|
|
|
|
|
|
|
HybridModelFile? model;
|
|
|
|
|
|
|
|
// First try hash match
|
|
|
|
if (parameters.ModelHash is not null)
|
|
|
|
{
|
|
|
|
model = currentModels.FirstOrDefault(
|
|
|
|
m =>
|
|
|
|
m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
|
|
|
|
&& sha256.StartsWith(
|
|
|
|
parameters.ModelHash,
|
|
|
|
StringComparison.InvariantCultureIgnoreCase
|
|
|
|
)
|
|
|
|
);
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
// Name matches
|
|
|
|
model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName));
|
|
|
|
model ??= currentModels.FirstOrDefault(
|
|
|
|
m => m.ShortDisplayName.StartsWith(paramsModelName)
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (model is not null)
|
|
|
|
{
|
|
|
|
SelectedModel = model;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// <inheritdoc />
|
|
|
|
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
|
|
|
|
{
|
|
|
|
return parameters with
|
|
|
|
{
|
|
|
|
ModelName = SelectedModel?.FileName,
|
|
|
|
ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
internal class ModelCardModel
|
|
|
|
{
|
|
|
|
public string? SelectedModelName { get; init; }
|
|
|
|
public string? SelectedVaeName { get; init; }
|
|
|
|
public bool IsVaeSelectionEnabled { get; init; }
|
|
|
|
}
|
|
|
|
}
|