Multi-Platform Package Manager for Stable Diffusion
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

284 lines
9.5 KiB

using System;
using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
/// <summary>
/// View model for the inference model card. Selects the base checkpoint, an optional refiner,
/// an optional VAE override and clip skip, and adds the matching loader nodes to the ComfyUI
/// workflow in <see cref="ApplyStep"/>.
/// </summary>
[View(typeof(ModelCard))]
[ManagedService]
[Transient]
public partial class ModelCardViewModel(IInferenceClientManager clientManager)
    : LoadableViewModelBase,
        IParametersLoadableState,
        IComfyStep
{
    /// <summary>Base checkpoint used for generation; <c>null</c> when nothing is selected.</summary>
    [ObservableProperty]
    private HybridModelFile? selectedModel;

    /// <summary>When true, <see cref="ApplyStep"/> also loads the refiner checkpoint.</summary>
    [ObservableProperty]
    private bool isRefinerSelectionEnabled;

    /// <summary>Refiner checkpoint; <see cref="HybridModelFile.None"/> means no refiner.</summary>
    [ObservableProperty]
    private HybridModelFile? selectedRefiner = HybridModelFile.None;

    /// <summary>
    /// VAE override; <see cref="HybridModelFile.Default"/> (or None) keeps the checkpoint's own VAE.
    /// </summary>
    [ObservableProperty]
    private HybridModelFile? selectedVae = HybridModelFile.Default;

    /// <summary>When true, <see cref="ApplyStep"/> loads <see cref="SelectedVae"/> as the primary VAE.</summary>
    [ObservableProperty]
    private bool isVaeSelectionEnabled;

    /// <summary>
    /// Not read by this class. Presumably bound by the view to disable the card's settings
    /// (TODO confirm against ModelCard).
    /// </summary>
    [ObservableProperty]
    private bool disableSettings;

    /// <summary>When true, <see cref="ApplyStep"/> inserts a CLIPSetLastLayer node for every model.</summary>
    [ObservableProperty]
    private bool isClipSkipEnabled;

    /// <summary>
    /// Clip skip value in the range 1 to 24. <see cref="ApplyStep"/> passes it to CLIPSetLastLayer
    /// as the negative layer index <c>-ClipSkip</c>.
    /// </summary>
    [NotifyDataErrorInfo]
    [ObservableProperty]
    [Range(1, 24)]
    private int clipSkip = 1;

    /// <summary>Source of the available checkpoint and VAE model lists.</summary>
    public IInferenceClientManager ClientManager { get; } = clientManager;

    /// <summary>
    /// Shows a help dialog explaining how to place a .yaml config file next to a model file.
    /// </summary>
    [RelayCommand]
    private static async Task OnConfigClickAsync()
    {
        await DialogHelper
            .CreateMarkdownDialog(
                """
                You can use a config (.yaml) file to load a model with specific settings.
                Place the config file next to the model file with the same name:
                ```md
                Models/
                StableDiffusion/
                my_model.safetensors
                my_model.yaml <-
                ```
                """,
                "Using Model Configs",
                TextEditorPreset.Console
            )
            .ShowAsync();
    }

    /// <summary>
    /// Checks that a base model is selected.
    /// </summary>
    /// <returns>
    /// <c>true</c> if <see cref="SelectedModel"/> is set. Otherwise shows a
    /// "No Model Selected" dialog and returns <c>false</c>.
    /// </returns>
    public async Task<bool> ValidateModel()
    {
        if (SelectedModel != null)
            return true;

        var dialog = DialogHelper.CreateMarkdownDialog(
            "Please select a model to continue.",
            "No Model Selected"
        );

        await dialog.ShowAsync();

        return false;
    }

    /// <summary>
    /// Builds a checkpoint loader node for <paramref name="model"/>.
    /// </summary>
    /// <remarks>
    /// If the local model has a config file, that file is registered for transfer via
    /// <c>AddFileTransferToConfigs</c>, and a CheckpointLoader referencing it by file name is returned.
    /// Otherwise a CheckpointLoaderSimple is returned.
    /// </remarks>
    /// <param name="e">Step arguments, used to register the config file transfer.</param>
    /// <param name="nodeName">Name to give the created node.</param>
    /// <param name="model">Checkpoint to load.</param>
    /// <returns>A node whose outputs are the model, CLIP and VAE connections.</returns>
    private static ComfyTypedNodeBase<
        ModelNodeConnection,
        ClipNodeConnection,
        VAENodeConnection
    > GetModelLoader(ModuleApplyStepEventArgs e, string nodeName, HybridModelFile model)
    {
        // Check if config
        if (model.Local?.ConfigFullPath is { } configPath)
        {
            // We'll need to upload the config file to `models/configs` later
            var uploadConfigPath = e.AddFileTransferToConfigs(configPath);

            return new ComfyNodeBuilder.CheckpointLoader
            {
                Name = nodeName,
                // Only the file name is needed
                ConfigName = Path.GetFileName(uploadConfigPath),
                CkptName = model.RelativePath
            };
        }

        // Simple loader if no config
        return new ComfyNodeBuilder.CheckpointLoaderSimple { Name = nodeName, CkptName = model.RelativePath };
    }

    /// <inheritdoc />
    /// <remarks>
    /// Adds nodes in this order:
    /// <list type="bullet">
    /// <item>the base checkpoint loader;</item>
    /// <item>the refiner loader, if enabled and not None;</item>
    /// <item>a VAE loader, if enabled and not None/Default;</item>
    /// <item>one CLIPSetLastLayer node per model with a CLIP output, if clip skip is enabled.</item>
    /// </list>
    /// </remarks>
    /// <exception cref="ValidationException">
    /// No base model is selected, or clip skip is enabled with a value outside 1 to 24.
    /// </exception>
    public virtual void ApplyStep(ModuleApplyStepEventArgs e)
    {
        // Load base checkpoint (required)
        var baseLoader = e.Nodes.AddTypedNode(
            GetModelLoader(
                e,
                "CheckpointLoader_Base",
                SelectedModel ?? throw new ValidationException("Model not selected")
            )
        );

        e.Builder.Connections.Base.Model = baseLoader.Output1;
        e.Builder.Connections.Base.Clip = baseLoader.Output2;
        e.Builder.Connections.Base.VAE = baseLoader.Output3;

        // Load refiner if enabled; the pattern also guarantees a non-null value
        if (IsRefinerSelectionEnabled && SelectedRefiner is { IsNone: false } refiner)
        {
            var refinerLoader = e.Nodes.AddTypedNode(GetModelLoader(e, "CheckpointLoader_Refiner", refiner));

            e.Builder.Connections.Refiner.Model = refinerLoader.Output1;
            e.Builder.Connections.Refiner.Clip = refinerLoader.Output2;
            e.Builder.Connections.Refiner.VAE = refinerLoader.Output3;
        }

        // Load VAE override if enabled and an actual file (not a placeholder) is selected
        if (IsVaeSelectionEnabled && SelectedVae is { IsNone: false, IsDefault: false } vae)
        {
            var vaeLoader = e.Nodes.AddTypedNode(
                new ComfyNodeBuilder.VAELoader { Name = "VAELoader", VaeName = vae.RelativePath }
            );

            e.Builder.Connections.PrimaryVAE = vaeLoader.Output;
        }

        // Clip skip all models if enabled
        if (IsClipSkipEnabled)
        {
            // [Range] only reports a validation error and does not block the assignment.
            // LoadStateFromJsonObject can also set any int. Reject values that would produce an
            // invalid layer index (e.g. 0) instead of sending them to the backend.
            if (ClipSkip is < 1 or > 24)
                throw new ValidationException($"Clip skip must be between 1 and 24, but was {ClipSkip}");

            foreach (var (modelName, model) in e.Builder.Connections.Models)
            {
                if (model.Clip is not { } modelClip)
                    continue;

                var clipSetLastLayer = e.Nodes.AddTypedNode(
                    new ComfyNodeBuilder.CLIPSetLastLayer
                    {
                        Name = $"CLIP_Skip_{modelName}",
                        Clip = modelClip,
                        // Need to convert to negative indexing from (1 to 24) to (-1 to -24)
                        StopAtClipLayer = -ClipSkip
                    }
                );

                model.Clip = clipSetLastLayer.Output;
            }
        }
    }

    /// <inheritdoc />
    /// <remarks>Models are persisted by their relative path.</remarks>
    public override JsonObject SaveStateToJsonObject()
    {
        return SerializeModel(
            new ModelCardModel
            {
                SelectedModelName = SelectedModel?.RelativePath,
                SelectedVaeName = SelectedVae?.RelativePath,
                SelectedRefinerName = SelectedRefiner?.RelativePath,
                ClipSkip = ClipSkip,
                IsVaeSelectionEnabled = IsVaeSelectionEnabled,
                IsRefinerSelectionEnabled = IsRefinerSelectionEnabled,
                IsClipSkipEnabled = IsClipSkipEnabled
            }
        );
    }

    /// <inheritdoc />
    /// <remarks>
    /// Models are looked up by exact relative path. A saved VAE or refiner that is no longer
    /// available falls back to <see cref="HybridModelFile.Default"/> or <see cref="HybridModelFile.None"/>.
    /// These are the same placeholders used when no name was saved.
    /// </remarks>
    public override void LoadStateFromJsonObject(JsonObject state)
    {
        var model = DeserializeModel<ModelCardModel>(state);

        SelectedModel = model.SelectedModelName is null
            ? null
            : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName);

        SelectedVae = model.SelectedVaeName is null
            ? HybridModelFile.Default
            : (
                ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName)
                ?? HybridModelFile.Default
            );

        SelectedRefiner = model.SelectedRefinerName is null
            ? HybridModelFile.None
            : (
                ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedRefinerName)
                ?? HybridModelFile.None
            );

        ClipSkip = model.ClipSkip;

        IsVaeSelectionEnabled = model.IsVaeSelectionEnabled;
        IsRefinerSelectionEnabled = model.IsRefinerSelectionEnabled;
        IsClipSkipEnabled = model.IsClipSkipEnabled;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Does nothing if the parameters have no model name. Otherwise looks for a model in this order:
    /// <list type="number">
    /// <item>SHA256 hash starting with <c>ModelHash</c> (case-insensitive), when a non-empty hash is given;</item>
    /// <item>relative path ending with the model name;</item>
    /// <item>short display name starting with the model name.</item>
    /// </list>
    /// <see cref="SelectedModel"/> is only changed when a match is found.
    /// </remarks>
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        if (parameters.ModelName is not { } paramsModelName)
            return;

        var currentModels = ClientManager.Models;

        HybridModelFile? model = null;

        // First try hash match. An empty hash would prefix-match every hashed model, so skip it.
        if (parameters.ModelHash is { Length: > 0 } modelHash)
        {
            model = currentModels.FirstOrDefault(
                m =>
                    m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
                    && sha256.StartsWith(modelHash, StringComparison.OrdinalIgnoreCase)
            );
        }

        // Fall back to name matching when there is no hash, or no local model has that hash
        model ??= currentModels.FirstOrDefault(
            m => m.RelativePath.EndsWith(paramsModelName, StringComparison.Ordinal)
        );
        model ??= currentModels.FirstOrDefault(
            m => m.ShortDisplayName.StartsWith(paramsModelName, StringComparison.Ordinal)
        );

        if (model is not null)
        {
            SelectedModel = model;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Writes the selected model's file name and, when known, its SHA256 hash.
    /// </remarks>
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        return parameters with
        {
            ModelName = SelectedModel?.FileName,
            ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
        };
    }

    /// <summary>
    /// Serialized state of the card, as written by <see cref="SaveStateToJsonObject"/>.
    /// Model references are stored as relative paths.
    /// </summary>
    internal class ModelCardModel
    {
        public string? SelectedModelName { get; init; }
        public string? SelectedRefinerName { get; init; }
        public string? SelectedVaeName { get; init; }
        public int ClipSkip { get; init; } = 1;

        public bool IsVaeSelectionEnabled { get; init; }
        public bool IsRefinerSelectionEnabled { get; init; }
        public bool IsClipSkipEnabled { get; init; }
    }
}