diff --git a/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs b/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs index 2de1f45d..7422eebb 100644 --- a/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs +++ b/StabilityMatrix.Core/Converters/Json/DefaultUnknownEnumConverter.cs @@ -1,38 +1,95 @@ -using System.Text.Json; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Runtime.Serialization; +using System.Text.Json; using System.Text.Json.Serialization; -using StabilityMatrix.Core.Extensions; namespace StabilityMatrix.Core.Converters.Json; -public class DefaultUnknownEnumConverter : JsonConverter +public class DefaultUnknownEnumConverter< + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] T +> : JsonConverter where T : Enum { + /// + /// Lazy initialization for . + /// + private readonly Lazy> _enumMemberValuesLazy = + new( + () => + typeof(T) + .GetFields() + .Where(field => field.IsStatic) + .Select( + field => + new + { + FieldName = field.Name, + FieldValue = (T)field.GetValue(null)!, + EnumMemberValue = field + .GetCustomAttributes(false) + .FirstOrDefault() + ?.Value?.ToString() + } + ) + .ToDictionary(x => x.EnumMemberValue ?? x.FieldName, x => x.FieldValue) + ); + + /// + /// Gets a dictionary of enum member values, keyed by the EnumMember attribute value, or the field name if no EnumMember attribute is present. + /// + private Dictionary EnumMemberValues => _enumMemberValuesLazy.Value; + + /// + /// Lazy initialization for . + /// + private readonly Lazy> _enumMemberNamesLazy; + + /// + /// Gets a dictionary of enum member names, keyed by the enum member value. + /// + private Dictionary EnumMemberNames => _enumMemberNamesLazy.Value; + + /// + /// Gets the value of the "Unknown" enum member, or the 0 value if no "Unknown" member is present. + /// + private T UnknownValue => + EnumMemberValues.TryGetValue("Unknown", out var res) ? res : (T)Enum.ToObject(typeof(T), 0); + + /// + public override bool HandleNull => true; + + public DefaultUnknownEnumConverter() + { + _enumMemberNamesLazy = new Lazy>( + () => EnumMemberValues.ToDictionary(x => x.Value, x => x.Key) + ); + } + + /// public override T Read( ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options ) { - if (reader.TokenType != JsonTokenType.String) - { - throw new JsonException(); - } - - var enumText = reader.GetString()?.Replace(" ", "_"); - if (Enum.TryParse(typeof(T), enumText, true, out var result)) + if (reader.TokenType is not (JsonTokenType.String or JsonTokenType.PropertyName)) { - return (T)result!; + throw new JsonException("Expected String or PropertyName token"); } - // Unknown value handling - if (Enum.TryParse(typeof(T), "Unknown", true, out var unknownResult)) + if (reader.GetString() is { } readerString) { - return (T)unknownResult!; + if (EnumMemberValues.TryGetValue(readerString, out var enumMemberValue)) + { + return enumMemberValue; + } } - throw new JsonException($"Unable to parse '{enumText}' to enum '{typeof(T)}'."); + return UnknownValue; } + /// public override void Write(Utf8JsonWriter writer, T? value, JsonSerializerOptions options) { if (value == null) @@ -41,7 +98,7 @@ public class DefaultUnknownEnumConverter : JsonConverter return; } - writer.WriteStringValue(value.GetStringValue().Replace("_", " ")); + writer.WriteStringValue(EnumMemberNames[value]); } /// @@ -49,41 +106,12 @@ public class DefaultUnknownEnumConverter : JsonConverter ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options - ) - { - if (reader.TokenType != JsonTokenType.PropertyName) - { - throw new JsonException(); - } - - var enumText = reader.GetString()?.Replace(" ", "_"); - if (Enum.TryParse(typeof(T), enumText, true, out var result)) - { - return (T)result!; - } - - // Unknown value handling - if (Enum.TryParse(typeof(T), "Unknown", true, out var unknownResult)) - { - return (T)unknownResult!; - } - - throw new JsonException($"Unable to parse '{enumText}' to enum '{typeof(T)}'."); - } + ) => Read(ref reader, typeToConvert, options); /// public override void WriteAsPropertyName( Utf8JsonWriter writer, T? value, JsonSerializerOptions options - ) - { - if (value == null) - { - writer.WriteNullValue(); - return; - } - - writer.WritePropertyName(value.GetStringValue().Replace("_", " ")); - } + ) => Write(writer, value, options); } diff --git a/StabilityMatrix.Core/Models/Api/CivitFileType.cs b/StabilityMatrix.Core/Models/Api/CivitFileType.cs index a2a59fda..b4108924 100644 --- a/StabilityMatrix.Core/Models/Api/CivitFileType.cs +++ b/StabilityMatrix.Core/Models/Api/CivitFileType.cs @@ -1,4 +1,5 @@ -using System.Text.Json.Serialization; +using System.Runtime.Serialization; +using System.Text.Json.Serialization; using StabilityMatrix.Core.Converters.Json; namespace StabilityMatrix.Core.Models.Api; @@ -6,8 +7,10 @@ namespace StabilityMatrix.Core.Models.Api; [JsonConverter(typeof(DefaultUnknownEnumConverter))] public enum CivitFileType { + Unknown, Model, VAE, - Training_Data, - Unknown, + + [EnumMember(Value = "Training Data")] + TrainingData } diff --git a/StabilityMatrix.Core/Models/Api/CivitModelType.cs b/StabilityMatrix.Core/Models/Api/CivitModelType.cs index 447ec92d..8cead165 100644 --- a/StabilityMatrix.Core/Models/Api/CivitModelType.cs +++ b/StabilityMatrix.Core/Models/Api/CivitModelType.cs @@ -9,6 +9,8 @@ namespace StabilityMatrix.Core.Models.Api; [SuppressMessage("ReSharper", "InconsistentNaming")] public enum CivitModelType { + Unknown, + [ConvertTo(SharedFolderType.StableDiffusion)] Checkpoint, @@ -39,6 +41,5 @@ public enum CivitModelType Wildcards, Workflows, Other, - All, - Unknown + All } diff --git a/StabilityMatrix.Tests/Core/DefaultUnknownEnumConverterTests.cs b/StabilityMatrix.Tests/Core/DefaultUnknownEnumConverterTests.cs new file mode 100644 index 00000000..b0930376 --- /dev/null +++ b/StabilityMatrix.Tests/Core/DefaultUnknownEnumConverterTests.cs @@ -0,0 +1,81 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using StabilityMatrix.Core.Converters.Json; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class DefaultUnknownEnumConverterTests +{ + [TestMethod] + [ExpectedException(typeof(JsonException))] + public void TestDeserialize_NormalEnum_ShouldError() + { + const string json = "\"SomeUnknownValue\""; + + JsonSerializer.Deserialize(json); + } + + [TestMethod] + public void TestDeserialize_UnknownEnum_ShouldConvert() + { + const string json = "\"SomeUnknownValue\""; + + var result = JsonSerializer.Deserialize(json); + + Assert.AreEqual(UnknownEnum.Unknown, result); + } + + [TestMethod] + public void TestDeserialize_DefaultEnum_ShouldConvert() + { + const string json = "\"SomeUnknownValue\""; + + var result = JsonSerializer.Deserialize(json); + + Assert.AreEqual(DefaultEnum.CustomDefault, result); + } + + [TestMethod] + public void TestSerialize_UnknownEnum_ShouldConvert() + { + const string expected = "\"Unknown\""; + + var result = JsonSerializer.Serialize(UnknownEnum.Unknown); + + Assert.AreEqual(expected, result); + } + + [TestMethod] + public void TestSerialize_DefaultEnum_ShouldConvert() + { + const string expected = "\"CustomDefault\""; + + var result = JsonSerializer.Serialize(DefaultEnum.CustomDefault); + + Assert.AreEqual(expected, result); + } + + private enum NormalEnum + { + Unknown, + Value1, + Value2 + } + + [JsonConverter(typeof(DefaultUnknownEnumConverter))] + private enum UnknownEnum + { + Unknown, + Value1, + Value2 + } + + [JsonConverter(typeof(DefaultUnknownEnumConverter))] + private enum DefaultEnum + { + CustomDefault, + Value1, + Value2 + } +}