Browse Source

Add Custom VAE loading

pull/165/head
Ionite 1 year ago
parent
commit
2dec41ce00
No known key found for this signature in database
  1. 8
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  2. 51
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  3. 4
      StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs
  4. 15
      StabilityMatrix.Core/Models/Api/Comfy/ComfySampler.cs
  5. 28
      StabilityMatrix.Core/Models/Api/Comfy/ComfyUpscaler.cs
  6. 14
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
  7. 20
      StabilityMatrix.Core/Models/Database/LocalModelFile.cs
  8. 20
      StabilityMatrix.Core/Models/HybridModelFile.Design.cs
  9. 65
      StabilityMatrix.Core/Models/HybridModelFile.cs
  10. 2
      StabilityMatrix.Core/Services/IModelIndexService.cs
  11. 12
      StabilityMatrix.Core/Services/ModelIndexService.cs

8
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -16,7 +16,7 @@ public partial class ModelCardViewModel : LoadableViewModelBase
private HybridModelFile? selectedModel;
[ObservableProperty]
private HybridModelFile? selectedVae;
private HybridModelFile? selectedVae = HybridModelFile.Default;
[ObservableProperty]
private bool isVaeSelectionEnabled;
@ -49,10 +49,10 @@ public partial class ModelCardViewModel : LoadableViewModelBase
var model = DeserializeModel<ModelCardModel>(state);
SelectedModel = model.SelectedModelName is null ? null
: ClientManager.Models!.FirstOrDefault(x => x.FileName == model.SelectedModelName);
: ClientManager.Models.FirstOrDefault(x => x.FileName == model.SelectedModelName);
SelectedVae = model.SelectedVaeName is null ? null
: ClientManager.VaeModels!.FirstOrDefault(x => x.FileName == model.SelectedVaeName);
SelectedVae = model.SelectedVaeName is null ? HybridModelFile.Default
: ClientManager.VaeModels.FirstOrDefault(x => x.FileName == model.SelectedVaeName);
}
internal class ModelCardModel

51
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -1,12 +1,21 @@
using System.Diagnostics;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AvaloniaEdit.Document;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
@ -16,6 +25,7 @@ public partial class PromptCardViewModel : LoadableViewModelBase
{
public ICompletionProvider CompletionProvider { get; }
public ITokenizerProvider TokenizerProvider { get; }
public SharedState SharedState { get; }
public TextDocument PromptDocument { get; } = new();
public TextDocument NegativePromptDocument { get; } = new();
@ -27,22 +37,55 @@ public partial class PromptCardViewModel : LoadableViewModelBase
public PromptCardViewModel(
ICompletionProvider completionProvider,
ITokenizerProvider tokenizerProvider,
ISettingsManager settingsManager)
ISettingsManager settingsManager,
SharedState sharedState)
{
CompletionProvider = completionProvider;
TokenizerProvider = tokenizerProvider;
SharedState = sharedState;
settingsManager.RelayPropertyFor(this,
vm => vm.IsAutoCompletionEnabled,
settings => settings.IsPromptCompletionEnabled,
true);
}
/// <summary>
/// Processes current positive prompt text into a Prompt object
/// </summary>
public Prompt GetPrompt()
{
return Prompt.FromRawText(PromptDocument.Text, TokenizerProvider);
}
partial void OnIsAutoCompletionEnabledChanged(bool value)
/// <summary>
/// Processes current negative prompt text into a Prompt object
/// </summary>
public Prompt GetNegativePrompt()
{
Debug.WriteLine("OnIsAutoCompletionEnabledChanged: " + value);
return Prompt.FromRawText(NegativePromptDocument.Text, TokenizerProvider);
}
[RelayCommand]
private async Task DebugShowTokens()
{
var prompt = GetPrompt();
var tokens = prompt.TokenizeResult.Tokens;
var builder = new StringBuilder();
builder.AppendLine($"Tokens ({tokens.Length}):");
builder.AppendLine("```csharp");
builder.AppendLine(prompt.GetDebugText());
builder.AppendLine("```");
var dialog = DialogHelper.CreateMarkdownDialog(builder.ToString(), "Prompt Tokens");
dialog.MinDialogWidth = 800;
dialog.MaxDialogHeight = 1000;
dialog.MaxDialogWidth = 1000;
await dialog.ShowAsync();
}
/// <inheritdoc />
public override JsonObject SaveStateToJsonObject()
{

4
StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs

@ -13,8 +13,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
public partial class UpscalerCardViewModel : LoadableViewModelBase
{
[ObservableProperty] private double scale = 1;
[ObservableProperty] private ComfyUpscaler? selectedUpscaler;
[ObservableProperty] private ComfyUpscaler? selectedUpscaler = ComfyUpscaler.Defaults[0];
public IInferenceClientManager ClientManager { get; }

15
StabilityMatrix.Core/Models/Api/Comfy/ComfySampler.cs

@ -31,4 +31,19 @@ public readonly record struct ComfySampler(string Name)
public string DisplayName =>
ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name;
private sealed class NameEqualityComparer : IEqualityComparer<ComfySampler>
{
public bool Equals(ComfySampler x, ComfySampler y)
{
return x.Name == y.Name;
}
public int GetHashCode(ComfySampler obj)
{
return obj.Name.GetHashCode();
}
}
public static IEqualityComparer<ComfySampler> Comparer { get; } = new NameEqualityComparer();
}

28
StabilityMatrix.Core/Models/Api/Comfy/ComfyUpscaler.cs

@ -1,4 +1,6 @@
namespace StabilityMatrix.Core.Models.Api.Comfy;
using System.Collections.Immutable;
namespace StabilityMatrix.Core.Models.Api.Comfy;
public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
{
@ -11,6 +13,9 @@ public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
["bicubic"] = "Bicubic",
["bislerp"] = "Bislerp",
};
public static IReadOnlyList<ComfyUpscaler> Defaults { get; } =
ConvertDict.Keys.Select(k => new ComfyUpscaler(k, ComfyUpscalerType.Latent)).ToImmutableArray();
public string DisplayType
{
@ -34,6 +39,12 @@ public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
{
return ConvertDict.TryGetValue(Name, out var displayName) ? displayName : Name;
}
if (Type == ComfyUpscalerType.ESRGAN)
{
// Remove file extensions
return Path.GetFileNameWithoutExtension(Name);
}
return Name;
}
@ -52,4 +63,19 @@ public readonly record struct ComfyUpscaler(string Name, ComfyUpscalerType Type)
return DisplayName;
}
}
private sealed class NameTypeEqualityComparer : IEqualityComparer<ComfyUpscaler>
{
public bool Equals(ComfyUpscaler x, ComfyUpscaler y)
{
return x.Name == y.Name && x.Type == y.Type;
}
public int GetHashCode(ComfyUpscaler obj)
{
return HashCode.Combine(obj.Name, (int) obj.Type);
}
}
public static IEqualityComparer<ComfyUpscaler> Comparer { get; } = new NameTypeEqualityComparer();
}

14
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -126,6 +126,20 @@ public class ComfyNodeBuilder
}
};
}
public static NamedComfyNode<VAENodeConnection> VAELoader(
string name,
string vaeModelName)
{
return new NamedComfyNode<VAENodeConnection>(name)
{
ClassType = "VAELoader",
Inputs = new Dictionary<string, object?>
{
["vae_name"] = vaeModelName
}
};
}
public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
{

20
StabilityMatrix.Core/Models/Database/LocalModelFile.cs

@ -50,6 +50,26 @@ public class LocalModelFile
public string? PreviewImageFullPathGlobal
=> GetPreviewImageFullPath(GlobalConfig.LibraryDir.JoinDir("Models"));
protected bool Equals(LocalModelFile other)
{
return RelativePath == other.RelativePath;
}
/// <inheritdoc />
public override bool Equals(object? obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((LocalModelFile) obj);
}
/// <inheritdoc />
public override int GetHashCode()
{
return RelativePath.GetHashCode();
}
public static readonly HashSet<string> SupportedCheckpointExtensions =
new() { ".safetensors", ".pt", ".ckpt", ".pth", ".bin" };
public static readonly HashSet<string> SupportedImageExtensions =

20
StabilityMatrix.Core/Models/HybridModelFile.Design.cs

@ -1,20 +0,0 @@
using System.ComponentModel;
namespace StabilityMatrix.Core.Models;
/// <summary>
/// Design time extensions for <see cref="HybridModelFile"/>.
/// </summary>
[DesignOnly(true)]
public partial record HybridModelFile
{
/// <summary>
/// Whether this instance is the default model.
/// </summary>
public bool IsDefault => ReferenceEquals(this, Default);
/// <summary>
/// Whether this instance is no model.
/// </summary>
public bool IsNone => ReferenceEquals(this, None);
}

65
StabilityMatrix.Core/Models/HybridModelFile.cs

@ -8,18 +8,18 @@ namespace StabilityMatrix.Core.Models;
/// Model file union that may be remote or local.
/// </summary>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public partial record HybridModelFile
public record HybridModelFile
{
/// <summary>
/// Singleton instance of <see cref="HybridModelFile"/> that represents use of a default model.
/// </summary>
public static HybridModelFile Default => new();
public static HybridModelFile Default { get; } = FromRemote("@default");
/// <summary>
/// Singleton instance of <see cref="HybridModelFile"/> that represents no model.
/// </summary>
public static HybridModelFile None => new();
public static HybridModelFile None { get; } = FromRemote("@none");
private string? RemoteName { get; init; }
public LocalModelFile? Local { get; init; }
@ -34,8 +34,24 @@ public partial record HybridModelFile
? RemoteName : Local.FileName;
[JsonIgnore]
public string ShortDisplayName => Path.GetFileNameWithoutExtension(FileName);
public string ShortDisplayName
{
get
{
if (IsNone)
{
return "None";
}
if (IsDefault)
{
return "Default";
}
return Path.GetFileNameWithoutExtension(FileName);
}
}
public static HybridModelFile FromLocal(LocalModelFile local)
{
return new HybridModelFile
@ -51,4 +67,41 @@ public partial record HybridModelFile
RemoteName = remoteName
};
}
public string GetId()
{
return $"{FileName};{IsNone};{IsDefault}";
}
private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer<HybridModelFile>
{
public bool Equals(HybridModelFile? x, HybridModelFile? y)
{
if (ReferenceEquals(x, y)) return true;
if (ReferenceEquals(x, null)) return false;
if (ReferenceEquals(y, null)) return false;
if (x.GetType() != y.GetType()) return false;
return Equals(x.FileName, y.FileName)
&& x.IsNone == y.IsNone
&& x.IsDefault == y.IsDefault;
}
public int GetHashCode(HybridModelFile obj)
{
return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.FileName);
}
}
/// <summary>
/// Whether this instance is the default model.
/// </summary>
public bool IsDefault => ReferenceEquals(this, Default);
/// <summary>
/// Whether this instance is no model.
/// </summary>
public bool IsNone => ReferenceEquals(this, None);
public static IEqualityComparer<HybridModelFile> Comparer { get; } = new RemoteNameLocalEqualityComparer();
}

2
StabilityMatrix.Core/Services/IModelIndexService.cs

@ -5,6 +5,8 @@ namespace StabilityMatrix.Core.Services;
public interface IModelIndexService
{
Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; }
/// <summary>
/// Refreshes the local model file index.
/// </summary>

12
StabilityMatrix.Core/Services/ModelIndexService.cs

@ -1,6 +1,7 @@
using System.Diagnostics;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
@ -13,6 +14,8 @@ public class ModelIndexService : IModelIndexService
private readonly ILiteDbContext liteDbContext;
private readonly ISettingsManager settingsManager;
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; private set; } = new();
public ModelIndexService(
ILogger<ModelIndexService> logger,
ILiteDbContext liteDbContext,
@ -62,6 +65,8 @@ public class ModelIndexService : IModelIndexService
var added = 0;
var newIndex = new Dictionary<SharedFolderType, List<LocalModelFile>>();
foreach (
var file in modelsDir.Info
.EnumerateFiles("*.*", SearchOption.AllDirectories)
@ -113,9 +118,16 @@ public class ModelIndexService : IModelIndexService
// Insert into database
await localModelFiles.InsertAsync(localModel).ConfigureAwait(false);
// Add to index
var list = newIndex.GetOrAdd(sharedFolderType);
list.Add(localModel);
added++;
}
// Update index
ModelIndex = newIndex;
// Record end of actual indexing
var indexEnd = stopwatch.Elapsed;

Loading…
Cancel
Save