You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
188 lines
5.9 KiB
188 lines
5.9 KiB
using System; |
|
using System.ComponentModel.DataAnnotations; |
|
using System.Linq; |
|
using System.Text.Json.Nodes; |
|
using CommunityToolkit.Mvvm.ComponentModel; |
|
using StabilityMatrix.Avalonia.Controls; |
|
using StabilityMatrix.Avalonia.Models; |
|
using StabilityMatrix.Avalonia.Models.Inference; |
|
using StabilityMatrix.Avalonia.Services; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Models; |
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; |
|
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Inference; |
|
|
|
[View(typeof(ModelCard))] |
|
[ManagedService] |
|
[Transient] |
|
public partial class ModelCardViewModel |
|
: LoadableViewModelBase, |
|
IParametersLoadableState, |
|
IComfyStep |
|
{ |
|
[ObservableProperty] |
|
private HybridModelFile? selectedModel; |
|
|
|
[ObservableProperty] |
|
private bool isRefinerSelectionEnabled; |
|
|
|
[ObservableProperty] |
|
private HybridModelFile? selectedRefiner = HybridModelFile.None; |
|
|
|
[ObservableProperty] |
|
private HybridModelFile? selectedVae = HybridModelFile.Default; |
|
|
|
[ObservableProperty] |
|
private bool isVaeSelectionEnabled; |
|
|
|
[ObservableProperty] |
|
private bool disableSettings; |
|
|
|
public string? SelectedModelName => SelectedModel?.RelativePath; |
|
|
|
public string? SelectedVaeName => SelectedVae?.RelativePath; |
|
|
|
public IInferenceClientManager ClientManager { get; } |
|
|
|
public ModelCardViewModel(IInferenceClientManager clientManager) |
|
{ |
|
ClientManager = clientManager; |
|
} |
|
|
|
/// <inheritdoc /> |
|
public virtual void ApplyStep(ModuleApplyStepEventArgs e) |
|
{ |
|
// Load base checkpoint |
|
var baseLoader = e.Nodes.AddTypedNode( |
|
new ComfyNodeBuilder.CheckpointLoaderSimple |
|
{ |
|
Name = "CheckpointLoader", |
|
CkptName = |
|
SelectedModel?.RelativePath |
|
?? throw new ValidationException("Model not selected") |
|
} |
|
); |
|
|
|
e.Builder.Connections.BaseModel = baseLoader.Output1; |
|
e.Builder.Connections.BaseClip = baseLoader.Output2; |
|
e.Builder.Connections.BaseVAE = baseLoader.Output3; |
|
|
|
// Load refiner |
|
if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false }) |
|
{ |
|
var refinerLoader = e.Nodes.AddTypedNode( |
|
new ComfyNodeBuilder.CheckpointLoaderSimple |
|
{ |
|
Name = "Refiner_CheckpointLoader", |
|
CkptName = |
|
SelectedRefiner?.RelativePath |
|
?? throw new ValidationException("Refiner Model enabled but not selected") |
|
} |
|
); |
|
|
|
e.Builder.Connections.RefinerModel = refinerLoader.Output1; |
|
e.Builder.Connections.RefinerClip = refinerLoader.Output2; |
|
e.Builder.Connections.RefinerVAE = refinerLoader.Output3; |
|
} |
|
|
|
// Load custom VAE |
|
if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false }) |
|
{ |
|
var vaeLoader = e.Nodes.AddTypedNode( |
|
new ComfyNodeBuilder.VAELoader |
|
{ |
|
Name = "VAELoader", |
|
VaeName = |
|
SelectedVae?.RelativePath |
|
?? throw new ValidationException("VAE enabled but not selected") |
|
} |
|
); |
|
|
|
e.Builder.Connections.PrimaryVAE = vaeLoader.Output; |
|
} |
|
} |
|
|
|
/// <inheritdoc /> |
|
public override JsonObject SaveStateToJsonObject() |
|
{ |
|
return SerializeModel( |
|
new ModelCardModel |
|
{ |
|
SelectedModelName = SelectedModelName, |
|
SelectedVaeName = SelectedVaeName, |
|
IsVaeSelectionEnabled = IsVaeSelectionEnabled |
|
} |
|
); |
|
} |
|
|
|
/// <inheritdoc /> |
|
public override void LoadStateFromJsonObject(JsonObject state) |
|
{ |
|
var model = DeserializeModel<ModelCardModel>(state); |
|
|
|
SelectedModel = model.SelectedModelName is null |
|
? null |
|
: ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); |
|
|
|
SelectedVae = model.SelectedVaeName is null |
|
? HybridModelFile.Default |
|
: ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName); |
|
} |
|
|
|
/// <inheritdoc /> |
|
public void LoadStateFromParameters(GenerationParameters parameters) |
|
{ |
|
if (parameters.ModelName is not { } paramsModelName) |
|
return; |
|
|
|
var currentModels = ClientManager.Models; |
|
|
|
HybridModelFile? model; |
|
|
|
// First try hash match |
|
if (parameters.ModelHash is not null) |
|
{ |
|
model = currentModels.FirstOrDefault( |
|
m => |
|
m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256 |
|
&& sha256.StartsWith( |
|
parameters.ModelHash, |
|
StringComparison.InvariantCultureIgnoreCase |
|
) |
|
); |
|
} |
|
else |
|
{ |
|
// Name matches |
|
model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName)); |
|
model ??= currentModels.FirstOrDefault( |
|
m => m.ShortDisplayName.StartsWith(paramsModelName) |
|
); |
|
} |
|
|
|
if (model is not null) |
|
{ |
|
SelectedModel = model; |
|
} |
|
} |
|
|
|
/// <inheritdoc /> |
|
public GenerationParameters SaveStateToParameters(GenerationParameters parameters) |
|
{ |
|
return parameters with |
|
{ |
|
ModelName = SelectedModel?.FileName, |
|
ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256 |
|
}; |
|
} |
|
|
|
internal class ModelCardModel |
|
{ |
|
public string? SelectedModelName { get; init; } |
|
public string? SelectedVaeName { get; init; } |
|
public bool IsVaeSelectionEnabled { get; init; } |
|
} |
|
}
|
|
|