diff --git a/StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs b/StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs index f496abc4..44d019fd 100644 --- a/StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs +++ b/StabilityMatrix.Core/Attributes/TypedNodeOptionsAttribute.cs @@ -1,4 +1,5 @@ using StabilityMatrix.Core.Models.Api.Comfy.Nodes; +using StabilityMatrix.Core.Models.Packages.Extensions; namespace StabilityMatrix.Core.Attributes; @@ -11,4 +12,9 @@ public class TypedNodeOptionsAttribute : Attribute public string? Name { get; init; } public string[]? RequiredExtensions { get; init; } + + public IEnumerable GetRequiredExtensionSpecifiers() + { + return RequiredExtensions?.Select(ExtensionSpecifier.Parse) ?? Enumerable.Empty(); + } } diff --git a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs index 5fbe4464..da088767 100644 --- a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs +++ b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/NodeDictionary.cs @@ -1,10 +1,12 @@ using System.ComponentModel; using System.Reflection; using System.Text.Json.Serialization; +using KGySoft.CoreLibraries; using OneOf; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; +using StabilityMatrix.Core.Models.Packages.Extensions; namespace StabilityMatrix.Core.Models.Api.Comfy.Nodes; @@ -19,7 +21,10 @@ public class NodeDictionary : Dictionary /// When inserting TypedNodes, this holds a mapping of ClassType to required extensions /// [JsonIgnore] - public Dictionary ClassTypeRequiredExtensions { get; } = new(); + public Dictionary ClassTypeRequiredExtensions { get; } = new(); + + public IEnumerable RequiredExtensions => + ClassTypeRequiredExtensions.Values.SelectMany(x => x); /// /// Finds a unique node name given a base name, by appending _2, _3, etc. @@ -63,7 +68,11 @@ public class NodeDictionary : Dictionary { if (options.RequiredExtensions != null) { - ClassTypeRequiredExtensions[namedNode.ClassType] = options.RequiredExtensions; + ClassTypeRequiredExtensions.AddOrUpdate( + namedNode.ClassType, + _ => options.GetRequiredExtensionSpecifiers().ToArray(), + (_, specifiers) => options.GetRequiredExtensionSpecifiers().Concat(specifiers).ToArray() + ); } } diff --git a/StabilityMatrix.Core/Models/Packages/Extensions/ExtensionSpecifier.cs b/StabilityMatrix.Core/Models/Packages/Extensions/ExtensionSpecifier.cs new file mode 100644 index 00000000..13003ca0 --- /dev/null +++ b/StabilityMatrix.Core/Models/Packages/Extensions/ExtensionSpecifier.cs @@ -0,0 +1,105 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.RegularExpressions; +using JetBrains.Annotations; +using Semver; +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Core.Models.Packages.Extensions; + +/// +/// Extension specifier with optional version constraints. +/// +[PublicAPI] +public partial class ExtensionSpecifier +{ + public required string Name { get; init; } + + public string? Constraint { get; init; } + + public string? Version { get; init; } + + public string? VersionConstraint => Constraint is null || Version is null ? null : Constraint + Version; + + public bool TryGetSemVersionRange([NotNullWhen(true)] out SemVersionRange? semVersionRange) + { + if (!string.IsNullOrEmpty(VersionConstraint)) + { + return SemVersionRange.TryParse( + VersionConstraint, + SemVersionRangeOptions.Loose, + out semVersionRange + ); + } + + semVersionRange = null; + return false; + } + + public static ExtensionSpecifier Parse(string value) + { + TryParse(value, true, out var packageSpecifier); + + return packageSpecifier!; + } + + public static bool TryParse(string value, [NotNullWhen(true)] out ExtensionSpecifier? extensionSpecifier) + { + return TryParse(value, false, out extensionSpecifier); + } + + private static bool TryParse( + string value, + bool throwOnFailure, + [NotNullWhen(true)] out ExtensionSpecifier? packageSpecifier + ) + { + var match = ExtensionSpecifierRegex().Match(value); + if (!match.Success) + { + if (throwOnFailure) + { + throw new ArgumentException($"Invalid extension specifier: {value}"); + } + + packageSpecifier = null; + return false; + } + + packageSpecifier = new ExtensionSpecifier + { + Name = match.Groups["extension_name"].Value, + Constraint = match.Groups["version_constraint"].Value, + Version = match.Groups["version"].Value + }; + + return true; + } + + /// + public override string ToString() + { + return Name + VersionConstraint; + } + + public static implicit operator Argument(ExtensionSpecifier specifier) + { + return specifier.VersionConstraint is null + ? new Argument(specifier.Name) + : new Argument((specifier.Name, specifier.VersionConstraint)); + } + + public static implicit operator ExtensionSpecifier(string specifier) + { + return Parse(specifier); + } + + /// + /// Regex to match a pip package specifier. + /// + [GeneratedRegex( + @"(?\S+)\s*(?==|>=|<=|>|<|~=|!=)?\s*(?[a-zA-Z0-9_.]+)?", + RegexOptions.CultureInvariant, + 5000 + )] + private static partial Regex ExtensionSpecifierRegex(); +}