using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Nodes;
using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// Inference card for choosing a base checkpoint, an optional refiner checkpoint and an
/// optional custom VAE, and for adding the matching loader nodes to a ComfyUI prompt.
/// </summary>
[View(typeof(ModelCard))]
[ManagedService]
[Transient]
public partial class ModelCardViewModel
    : LoadableViewModelBase,
        IParametersLoadableState,
        IComfyStep
{
    /// <summary>Base checkpoint; <see cref="ApplyStep"/> fails if this is not set.</summary>
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(SelectedModelName))]
    private HybridModelFile? selectedModel;

    /// <summary>Whether the refiner checkpoint should be loaded.</summary>
    [ObservableProperty]
    private bool isRefinerSelectionEnabled;

    /// <summary>Refiner checkpoint; <see cref="HybridModelFile.None"/> means no refiner.</summary>
    [ObservableProperty]
    private HybridModelFile? selectedRefiner = HybridModelFile.None;

    /// <summary>
    /// Custom VAE; <see cref="HybridModelFile.Default"/> keeps the VAE bundled with the checkpoint.
    /// </summary>
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(SelectedVaeName))]
    private HybridModelFile? selectedVae = HybridModelFile.Default;

    /// <summary>Whether the custom VAE should be loaded.</summary>
    [ObservableProperty]
    private bool isVaeSelectionEnabled;

    /// <summary>Relative path of <see cref="SelectedModel"/>, or null if none is selected.</summary>
    public string? SelectedModelName => SelectedModel?.RelativePath;

    /// <summary>Relative path of <see cref="SelectedVae"/>, or null if none is selected.</summary>
    public string? SelectedVaeName => SelectedVae?.RelativePath;

    /// <summary>Source of the available checkpoint and VAE lists.</summary>
    public IInferenceClientManager ClientManager { get; }

    public ModelCardViewModel(IInferenceClientManager clientManager)
    {
        ClientManager = clientManager;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Always adds a base checkpoint loader and routes its outputs to the base
    /// model/clip/VAE connections. Adds a refiner checkpoint loader when the refiner is
    /// enabled and not <see cref="HybridModelFile.None"/>, and a VAE loader (overriding
    /// the primary VAE) when the VAE is enabled and neither None nor Default.
    /// </remarks>
    /// <exception cref="ValidationException">No base model is selected.</exception>
    public void ApplyStep(ModuleApplyStepEventArgs e)
    {
        // Load base checkpoint
        var baseLoader = e.Nodes.AddTypedNode(
            new ComfyNodeBuilder.CheckpointLoaderSimple
            {
                Name = "CheckpointLoader",
                CkptName =
                    SelectedModel?.RelativePath
                    ?? throw new ValidationException("Model not selected")
            }
        );

        e.Builder.Connections.BaseModel = baseLoader.Output1;
        e.Builder.Connections.BaseClip = baseLoader.Output2;
        e.Builder.Connections.BaseVAE = baseLoader.Output3;

        // Load refiner
        if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false } refiner)
        {
            var refinerLoader = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.CheckpointLoaderSimple
                {
                    Name = "Refiner_CheckpointLoader",
                    CkptName =
                        refiner.RelativePath
                        ?? throw new ValidationException("Refiner Model enabled but not selected")
                }
            );

            e.Builder.Connections.RefinerModel = refinerLoader.Output1;
            e.Builder.Connections.RefinerClip = refinerLoader.Output2;
            e.Builder.Connections.RefinerVAE = refinerLoader.Output3;
        }

        // Load custom VAE
        if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false } vae)
        {
            var vaeLoader = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.VAELoader
                {
                    Name = "VAELoader",
                    VaeName =
                        vae.RelativePath
                        ?? throw new ValidationException("VAE enabled but not selected")
                }
            );

            e.Builder.Connections.PrimaryVAE = vaeLoader.Output;
        }
    }

    /// <inheritdoc />
    public override JsonObject SaveStateToJsonObject()
    {
        return SerializeModel(
            new ModelCardModel
            {
                SelectedModelName = SelectedModelName,
                SelectedRefinerName = SelectedRefiner?.RelativePath,
                SelectedVaeName = SelectedVaeName,
                IsRefinerSelectionEnabled = IsRefinerSelectionEnabled,
                IsVaeSelectionEnabled = IsVaeSelectionEnabled
            }
        );
    }

    /// <inheritdoc />
    /// <remarks>
    /// Saved names are resolved against the current client model lists. A refiner or VAE
    /// name that no longer resolves falls back to the card's default selection
    /// (<see cref="HybridModelFile.None"/> / <see cref="HybridModelFile.Default"/>)
    /// instead of leaving the selection null.
    /// </remarks>
    public override void LoadStateFromJsonObject(JsonObject state)
    {
        var model = DeserializeModel<ModelCardModel>(state);

        SelectedModel = model.SelectedModelName is null
            ? null
            : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName);

        // Refiners are checkpoints, so they are looked up in the same list as the base model.
        SelectedRefiner = model.SelectedRefinerName is null
            ? HybridModelFile.None
            : (
                ClientManager.Models.FirstOrDefault(
                    x => x.RelativePath == model.SelectedRefinerName
                ) ?? HybridModelFile.None
            );

        SelectedVae = model.SelectedVaeName is null
            ? HybridModelFile.Default
            : (
                ClientManager.VaeModels.FirstOrDefault(
                    x => x.RelativePath == model.SelectedVaeName
                ) ?? HybridModelFile.Default
            );

        IsRefinerSelectionEnabled = model.IsRefinerSelectionEnabled;
        IsVaeSelectionEnabled = model.IsVaeSelectionEnabled;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Tries a hash match first (the parameter hash is compared as a case-insensitive
    /// prefix of the model's SHA256), then falls back to matching by name when there is
    /// no hash or no model has that hash. Leaves <see cref="SelectedModel"/> unchanged if
    /// nothing matches or the parameters carry no model name.
    /// </remarks>
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        // An empty name would match every model via EndsWith/StartsWith("").
        if (parameters.ModelName is not { } paramsModelName || string.IsNullOrWhiteSpace(paramsModelName))
        {
            return;
        }

        var currentModels = ClientManager.Models;

        HybridModelFile? model = null;

        // First try hash match (hex hash, so ordinal case-insensitive comparison).
        if (parameters.ModelHash is { } paramsModelHash && !string.IsNullOrWhiteSpace(paramsModelHash))
        {
            model = currentModels.FirstOrDefault(
                m =>
                    m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
                    && sha256.StartsWith(paramsModelHash, StringComparison.OrdinalIgnoreCase)
            );
        }

        // Name matches, used when there is no hash or the hash matched nothing.
        model ??= currentModels.FirstOrDefault(
            m => m.RelativePath.EndsWith(paramsModelName, StringComparison.Ordinal)
        );
        model ??= currentModels.FirstOrDefault(
            m => m.ShortDisplayName.StartsWith(paramsModelName, StringComparison.Ordinal)
        );

        if (model is not null)
        {
            SelectedModel = model;
        }
    }

    /// <inheritdoc />
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        return parameters with
        {
            ModelName = SelectedModel?.FileName,
            ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
        };
    }

    /// <summary>
    /// Serialized state of the card. Fields added later are nullable / default-false so
    /// that state saved by older versions still deserializes.
    /// </summary>
    internal class ModelCardModel
    {
        public string? SelectedModelName { get; init; }
        public string? SelectedRefinerName { get; init; }
        public string? SelectedVaeName { get; init; }
        public bool IsRefinerSelectionEnabled { get; init; }
        public bool IsVaeSelectionEnabled { get; init; }
    }
}