Multi-Platform Package Manager for Stable Diffusion
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

186 lines
5.8 KiB

using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Nodes;
using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ModelCard))]
[ManagedService]
[Transient]
public partial class ModelCardViewModel
: LoadableViewModelBase,
IParametersLoadableState,
IComfyStep
{
[ObservableProperty]
private HybridModelFile? selectedModel;
[ObservableProperty]
private bool isRefinerSelectionEnabled;
[ObservableProperty]
private HybridModelFile? selectedRefiner = HybridModelFile.None;
[ObservableProperty]
private HybridModelFile? selectedVae = HybridModelFile.Default;
[ObservableProperty]
private bool isVaeSelectionEnabled;
public string? SelectedModelName => SelectedModel?.RelativePath;
public string? SelectedVaeName => SelectedVae?.RelativePath;
public IInferenceClientManager ClientManager { get; }
public ModelCardViewModel(IInferenceClientManager clientManager)
{
ClientManager = clientManager;
}
/// <inheritdoc />
public void ApplyStep(ModuleApplyStepEventArgs e)
{
// Load base checkpoint
var baseLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.CheckpointLoaderSimple
{
Name = "CheckpointLoader",
CkptName =
SelectedModel?.RelativePath
?? throw new ValidationException("Model not selected")
}
);
e.Builder.Connections.BaseModel = baseLoader.Output1;
e.Builder.Connections.BaseClip = baseLoader.Output2;
e.Builder.Connections.BaseVAE = baseLoader.Output3;
// Load refiner
if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false })
{
var refinerLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.CheckpointLoaderSimple
{
Name = "Refiner_CheckpointLoader",
CkptName =
SelectedRefiner?.RelativePath
?? throw new ValidationException("Refiner Model enabled but not selected")
}
);
e.Builder.Connections.RefinerModel = refinerLoader.Output1;
e.Builder.Connections.RefinerClip = refinerLoader.Output2;
e.Builder.Connections.RefinerVAE = refinerLoader.Output3;
}
// Load custom VAE
if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false })
{
var vaeLoader = e.Nodes.AddTypedNode(
new ComfyNodeBuilder.VAELoader
{
Name = "VAELoader",
VaeName =
SelectedVae?.RelativePath
?? throw new ValidationException("VAE enabled but not selected")
}
);
e.Builder.Connections.PrimaryVAE = vaeLoader.Output;
}
}
/// <inheritdoc />
public override JsonObject SaveStateToJsonObject()
{
return SerializeModel(
new ModelCardModel
{
SelectedModelName = SelectedModelName,
SelectedVaeName = SelectedVaeName,
IsVaeSelectionEnabled = IsVaeSelectionEnabled
}
);
}
/// <inheritdoc />
public override void LoadStateFromJsonObject(JsonObject state)
{
var model = DeserializeModel<ModelCardModel>(state);
SelectedModel = model.SelectedModelName is null
? null
: ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName);
SelectedVae = model.SelectedVaeName is null
? HybridModelFile.Default
: ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName);
}
internal class ModelCardModel
{
public string? SelectedModelName { get; init; }
public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; }
}
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
if (parameters.ModelName is not { } paramsModelName)
return;
var currentModels = ClientManager.Models;
HybridModelFile? model;
// First try hash match
if (parameters.ModelHash is not null)
{
model = currentModels.FirstOrDefault(
m =>
m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
&& sha256.StartsWith(
parameters.ModelHash,
StringComparison.InvariantCultureIgnoreCase
)
);
}
else
{
// Name matches
model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName));
model ??= currentModels.FirstOrDefault(
m => m.ShortDisplayName.StartsWith(paramsModelName)
);
}
if (model is not null)
{
SelectedModel = model;
}
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return parameters with
{
ModelName = SelectedModel?.FileName,
ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
};
}
}