Browse Source

Add Custom VAE loading

pull/165/head
Ionite 1 year ago
parent
commit
2dec41ce00
No known key found for this signature in database
  1. 8
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  2. 51
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  3. 2
      StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs
  4. 15
      StabilityMatrix.Core/Models/Api/Comfy/ComfySampler.cs
  5. 28
      StabilityMatrix.Core/Models/Api/Comfy/ComfyUpscaler.cs
  6. 14
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  7. 20
      StabilityMatrix.Core/Models/Database/LocalModelFile.cs
  8. 20
      StabilityMatrix.Core/Models/HybridModelFile.Design.cs
  9. 61
      StabilityMatrix.Core/Models/HybridModelFile.cs
  10. 2
      StabilityMatrix.Core/Services/IModelIndexService.cs
  11. 12
      StabilityMatrix.Core/Services/ModelIndexService.cs

8
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -16,7 +16,7 @@ public partial class ModelCardViewModel : LoadableViewModelBase
private HybridModelFile? selectedModel; private HybridModelFile? selectedModel;
[ObservableProperty] [ObservableProperty]
private HybridModelFile? selectedVae; private HybridModelFile? selectedVae = HybridModelFile.Default;
[ObservableProperty] [ObservableProperty]
private bool isVaeSelectionEnabled; private bool isVaeSelectionEnabled;
@ -49,10 +49,10 @@ public partial class ModelCardViewModel : LoadableViewModelBase
var model = DeserializeModel<ModelCardModel>(state); var model = DeserializeModel<ModelCardModel>(state);
SelectedModel = model.SelectedModelName is null ? null SelectedModel = model.SelectedModelName is null ? null
: ClientManager.Models!.FirstOrDefault(x => x.FileName == model.SelectedModelName); : ClientManager.Models.FirstOrDefault(x => x.FileName == model.SelectedModelName);
SelectedVae = model.SelectedVaeName is null ? null SelectedVae = model.SelectedVaeName is null ? HybridModelFile.Default
: ClientManager.VaeModels!.FirstOrDefault(x => x.FileName == model.SelectedVaeName); : ClientManager.VaeModels.FirstOrDefault(x => x.FileName == model.SelectedVaeName);
} }
internal class ModelCardModel internal class ModelCardModel

51
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,12 +1,21 @@
using System.Diagnostics; using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AvaloniaEdit.Document; using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -16,6 +25,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase
{ {
public ICompletionProvider CompletionProvider { get; } public ICompletionProvider CompletionProvider { get; }
public ITokenizerProvider TokenizerProvider { get; } public ITokenizerProvider TokenizerProvider { get; }
public SharedState SharedState { get; }
public TextDocument PromptDocument { get; } = new(); public TextDocument PromptDocument { get; } = new();
public TextDocument NegativePromptDocument { get; } = new(); public TextDocument NegativePromptDocument { get; } = new();
@ -27,10 +37,12 @@ public partial class PromptCardViewModel : LoadableViewModelBase
public PromptCardViewModel( public PromptCardViewModel(
ICompletionProvider completionProvider, ICompletionProvider completionProvider,
ITokenizerProvider tokenizerProvider, ITokenizerProvider tokenizerProvider,
ISettingsManager settingsManager) ISettingsManager settingsManager,
SharedState sharedState)
{ {
CompletionProvider = completionProvider; CompletionProvider = completionProvider;
TokenizerProvider = tokenizerProvider; TokenizerProvider = tokenizerProvider;
SharedState = sharedState;
settingsManager.RelayPropertyFor(this, settingsManager.RelayPropertyFor(this,
vm => vm.IsAutoCompletionEnabled, vm => vm.IsAutoCompletionEnabled,
@ -38,9 +50,40 @@ public partial class PromptCardViewModel : LoadableViewModelBase
true); true);
} }
partial void OnIsAutoCompletionEnabledChanged(bool value) /// <summary>
/// Processes current positive prompt text into a Prompt object
/// </summary>
public Prompt GetPrompt()
{ {
Debug.WriteLine("OnIsAutoCompletionEnabledChanged: " + value); return Prompt.FromRawText(PromptDocument.Text, TokenizerProvider);
}
/// <summary>
/// Processes current negative prompt text into a Prompt object
/// </summary>
public Prompt GetNegativePrompt()
{
return Prompt.FromRawText(NegativePromptDocument.Text, TokenizerProvider);
}
[RelayCommand]
private async Task DebugShowTokens()
{
var prompt = GetPrompt();
var tokens = prompt.TokenizeResult.Tokens;
var builder = new StringBuilder();
builder.AppendLine($"Tokens ({tokens.Length}):");
builder.AppendLine("```csharp");
builder.AppendLine(prompt.GetDebugText());
builder.AppendLine("```");
var dialog = DialogHelper.CreateMarkdownDialog(builder.ToString(), "Prompt Tokens");
dialog.MinDialogWidth = 800;
dialog.MaxDialogHeight = 1000;
dialog.MaxDialogWidth = 1000;
await dialog.ShowAsync();
} }
/// <inheritdoc /> /// <inheritdoc />

2
StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs

@ -14,7 +14,7 @@ public partial class UpscalerCardViewModel : LoadableViewModelBase
{ {
[ObservableProperty] private double scale = 1; [ObservableProperty] private double scale = 1;
[ObservableProperty] private ComfyUpscaler? selectedUpscaler; [ObservableProperty] private ComfyUpscaler? selectedUpscaler = ComfyUpscaler.Defaults[0];
public IInferenceClientManager ClientManager { get; } public IInferenceClientManager ClientManager { get; }

15
StabilityMatrix.Core/Models/Api/Comfy/ComfySampler.cs

@ -31,4 +31,19 @@ public readonly record struct ComfySampler(string Name)
public string DisplayName => public string DisplayName =>
ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name; ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name;
private sealed class NameEqualityComparer : IEqualityComparer<ComfySampler>
{
public bool Equals(ComfySampler x, ComfySampler y)
{
return x.Name == y.Name;
}
public int GetHashCode(ComfySampler obj)
{
return obj.Name.GetHashCode();
}
}
public static IEqualityComparer<ComfySampler> Comparer { get; } = new NameEqualityComparer();
} }

28
StabilityMatrix.Core/Models/Api/Comfy/ComfyUpscaler.cs

@ -1,4 +1,6 @@
namespace StabilityMatrix.Core.Models.Api.Comfy; using System.Collections.Immutable;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type) public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
{ {
@ -12,6 +14,9 @@ public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
["bislerp"] = "Bislerp", ["bislerp"] = "Bislerp",
}; };
public static IReadOnlyList<ComfyUpscaler> Defaults { get; } =
ConvertDict.Keys.Select(k => new ComfyUpscaler(k, ComfyUpscalerType.Latent)).ToImmutableArray();
public string DisplayType public string DisplayType
{ {
get get
@ -35,6 +40,12 @@ public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
return ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name; return ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name;
} }
if (Type == ComfyUpscalerType.ESRGAN)
{
// Remove file extensions
return Path.GetFileNameWithoutExtension(Name);
}
return Name; return Name;
} }
} }
@ -52,4 +63,19 @@ public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
return DisplayName; return DisplayName;
} }
} }
private sealed class NameTypeEqualityComparer : IEqualityComparer<ComfyUpscaler>
{
public bool Equals(ComfyUpscaler x, ComfyUpscaler y)
{
return x.Name == y.Name && x.Type == y.Type;
}
public int GetHashCode(ComfyUpscaler obj)
{
return HashCode.Combine(obj.Name, (int) obj.Type);
}
}
public static IEqualityComparer<ComfyUpscaler> Comparer { get; } = new NameTypeEqualityComparer();
} }

14
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -127,6 +127,20 @@ public class ComfyNodeBuilder
}; };
} }
public static NamedComfyNode<VAENodeConnection> VAELoader(
string name,
string vaeModelName)
{
return new NamedComfyNode<VAENodeConnection>(name)
{
ClassType = "VAELoader",
Inputs = new Dictionary<string, object?>
{
["vae_name"] = vaeModelName
}
};
}
public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae) public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
{ {
return nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output; return nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output;

20
StabilityMatrix.Core/Models/Database/LocalModelFile.cs

@ -50,6 +50,26 @@ public class LocalModelFile
public string? PreviewImageFullPathGlobal public string? PreviewImageFullPathGlobal
=> GetPreviewImageFullPath(GlobalConfig.LibraryDir.JoinDir("Models")); => GetPreviewImageFullPath(GlobalConfig.LibraryDir.JoinDir("Models"));
protected bool Equals(LocalModelFile other)
{
return RelativePath == other.RelativePath;
}
/// <inheritdoc />
public override bool Equals(object? obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((LocalModelFile) obj);
}
/// <inheritdoc />
public override int GetHashCode()
{
return RelativePath.GetHashCode();
}
public static readonly HashSet<string> SupportedCheckpointExtensions = public static readonly HashSet<string> SupportedCheckpointExtensions =
new() { ".safetensors", ".pt", ".ckpt", ".pth", ".bin" }; new() { ".safetensors", ".pt", ".ckpt", ".pth", ".bin" };
public static readonly HashSet<string> SupportedImageExtensions = public static readonly HashSet<string> SupportedImageExtensions =

20
StabilityMatrix.Core/Models/HybridModelFile.Design.cs

@ -1,20 +0,0 @@
using System.ComponentModel;
namespace StabilityMatrix.Core.Models;
/// <summary>
/// Design time extensions for <see cref="HybridModelFile"/>.
/// </summary>
[DesignOnly(true)]
public partial record HybridModelFile
{
/// <summary>
/// Whether this instance is the default model.
/// </summary>
public bool IsDefault => ReferenceEquals(this, Default);
/// <summary>
/// Whether this instance is no model.
/// </summary>
public bool IsNone => ReferenceEquals(this, None);
}

61
StabilityMatrix.Core/Models/HybridModelFile.cs

@ -8,17 +8,17 @@ namespace StabilityMatrix.Core.Models;
/// Model file union that may be remote or local. /// Model file union that may be remote or local.
/// </summary> /// </summary>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public partial record HybridModelFile public record HybridModelFile
{ {
/// <summary> /// <summary>
/// Singleton instance of <see cref="HybridModelFile"/> that represents use of a default model. /// Singleton instance of <see cref="HybridModelFile"/> that represents use of a default model.
/// </summary> /// </summary>
public static HybridModelFile Default => new(); public static HybridModelFile Default { get; } = FromRemote("@default");
/// <summary> /// <summary>
/// Singleton instance of <see cref="HybridModelFile"/> that represents no model. /// Singleton instance of <see cref="HybridModelFile"/> that represents no model.
/// </summary> /// </summary>
public static HybridModelFile None => new(); public static HybridModelFile None { get; } = FromRemote("@none");
private string? RemoteName { get; init; } private string? RemoteName { get; init; }
@ -34,7 +34,23 @@ public partial record HybridModelFile
? RemoteName : Local.FileName; ? RemoteName : Local.FileName;
[JsonIgnore] [JsonIgnore]
public string ShortDisplayName => Path.GetFileNameWithoutExtension(FileName); public string ShortDisplayName
{
get
{
if (IsNone)
{
return "None";
}
if (IsDefault)
{
return "Default";
}
return Path.GetFileNameWithoutExtension(FileName);
}
}
public static HybridModelFile FromLocal(LocalModelFile local) public static HybridModelFile FromLocal(LocalModelFile local)
{ {
@ -51,4 +67,41 @@ public partial record HybridModelFile
RemoteName = remoteName RemoteName = remoteName
}; };
} }
public string GetId()
{
return $"{FileName};{IsNone};{IsDefault}";
}
private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer<HybridModelFile>
{
public bool Equals(HybridModelFile? x, HybridModelFile? y)
{
if (ReferenceEquals(x, y)) return true;
if (ReferenceEquals(x, null)) return false;
if (ReferenceEquals(y, null)) return false;
if (x.GetType() != y.GetType()) return false;
return Equals(x.FileName, y.FileName)
&& x.IsNone == y.IsNone
&& x.IsDefault == y.IsDefault;
}
public int GetHashCode(HybridModelFile obj)
{
return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.FileName);
}
}
/// <summary>
/// Whether this instance is the default model.
/// </summary>
public bool IsDefault => ReferenceEquals(this, Default);
/// <summary>
/// Whether this instance is no model.
/// </summary>
public bool IsNone => ReferenceEquals(this, None);
public static IEqualityComparer<HybridModelFile> Comparer { get; } = new RemoteNameLocalEqualityComparer();
} }

2
StabilityMatrix.Core/Services/IModelIndexService.cs

@ -5,6 +5,8 @@ namespace StabilityMatrix.Core.Services;
public interface IModelIndexService public interface IModelIndexService
{ {
Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; }
/// <summary> /// <summary>
/// Refreshes the local model file index. /// Refreshes the local model file index.
/// </summary> /// </summary>

12
StabilityMatrix.Core/Services/ModelIndexService.cs

@ -1,6 +1,7 @@
using System.Diagnostics; using System.Diagnostics;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Database; using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.FileInterfaces;
@ -13,6 +14,8 @@ public class ModelIndexService : IModelIndexService
private readonly ILiteDbContext liteDbContext; private readonly ILiteDbContext liteDbContext;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; private set; } = new();
public ModelIndexService( public ModelIndexService(
ILogger<ModelIndexService> logger, ILogger<ModelIndexService> logger,
ILiteDbContext liteDbContext, ILiteDbContext liteDbContext,
@ -62,6 +65,8 @@ public class ModelIndexService : IModelIndexService
var added = 0; var added = 0;
var newIndex = new Dictionary<SharedFolderType, List<LocalModelFile>>();
foreach ( foreach (
var file in modelsDir.Info var file in modelsDir.Info
.EnumerateFiles("*.*", SearchOption.AllDirectories) .EnumerateFiles("*.*", SearchOption.AllDirectories)
@ -113,9 +118,16 @@ public class ModelIndexService : IModelIndexService
// Insert into database // Insert into database
await localModelFiles.InsertAsync(localModel).ConfigureAwait(false); await localModelFiles.InsertAsync(localModel).ConfigureAwait(false);
// Add to index
var list = newIndex.GetOrAdd(sharedFolderType);
list.Add(localModel);
added++; added++;
} }
// Update index
ModelIndex = newIndex;
// Record end of actual indexing // Record end of actual indexing
var indexEnd = stopwatch.Elapsed; var indexEnd = stopwatch.Elapsed;

Loading…
Cancel
Save