using System;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference;
namespace StabilityMatrix.Avalonia.Models;
/// <summary>
/// This is the project file for inference tabs.
/// Holds a format <see cref="Version"/>, the <see cref="ProjectType"/> of the tab,
/// and the tab's saved <see cref="State"/> as a JSON object.
/// </summary>
[JsonSerializable(typeof(InferenceProjectDocument))]
public class InferenceProjectDocument : ICloneable
{
    // NOTE(review): not referenced by any member of this class; kept unchanged.
    [JsonIgnore]
    private static readonly JsonSerializerOptions SerializerOptions =
        new() { IgnoreReadOnlyProperties = true, WriteIndented = true, };

    /// <summary>
    /// Project format version. New documents default to 2;
    /// <see cref="VerifyVersion"/> rejects anything lower.
    /// </summary>
    public int Version { get; set; } = 2;

    /// <summary>
    /// Kind of inference tab this document belongs to. Serialized as a string.
    /// </summary>
    [JsonConverter(typeof(JsonStringEnumConverter))]
    public InferenceProjectType ProjectType { get; set; }

    /// <summary>
    /// Saved state of the tab, or null when no state is present.
    /// </summary>
    public JsonObject? State { get; set; }

    /// <summary>
    /// Creates a document from a loadable view model, choosing the
    /// <see cref="ProjectType"/> from the model's runtime type and capturing its state.
    /// </summary>
    /// <param name="loadableModel">View model to capture.</param>
    /// <returns>A new document with <see cref="ProjectType"/> and <see cref="State"/> set.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="loadableModel"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The model type is not one of the known inference view models.</exception>
    public static InferenceProjectDocument FromLoadable(IJsonLoadableState loadableModel)
    {
        // Without this guard a null argument falls through to the default arm
        // and fails with a NullReferenceException while building the error message.
        ArgumentNullException.ThrowIfNull(loadableModel);

        return new InferenceProjectDocument
        {
            // Arm order is preserved from the original.
            // NOTE(review): order matters if these view models derive from one another - confirm.
            ProjectType = loadableModel switch
            {
                InferenceImageToImageViewModel => InferenceProjectType.ImageToImage,
                InferenceTextToImageViewModel => InferenceProjectType.TextToImage,
                InferenceImageUpscaleViewModel => InferenceProjectType.Upscale,
                _ => throw new InvalidOperationException($"Unknown loadable model type: {loadableModel.GetType()}")
            },
            State = loadableModel.SaveStateToJsonObject()
        };
    }

    /// <summary>
    /// Ensures the document version is supported.
    /// </summary>
    /// <exception cref="NotSupportedException"><see cref="Version"/> is less than 2.</exception>
    public void VerifyVersion()
    {
        if (Version < 2)
        {
            throw new NotSupportedException(
                "Project was created in an earlier pre-release version of Stability Matrix and is no longer supported. "
                    + "Please create a new project."
            );
        }
    }

    /// <summary>
    /// Reads the "Seed" entry of <see cref="State"/> as a <see cref="SeedCardModel"/>.
    /// </summary>
    /// <returns>
    /// The deserialized model, or null when <see cref="State"/> is null
    /// or has no "Seed" property.
    /// </returns>
    public SeedCardModel? GetSeedModel()
    {
        if (State is null || !State.TryGetPropertyValue("Seed", out var seedCard))
        {
            return null;
        }

        return seedCard.Deserialize<SeedCardModel>();
    }

    /// <summary>
    /// Returns a new <see cref="InferenceProjectDocument"/> with the State modified.
    /// This instance is left untouched; the modifier operates on the clone's state.
    /// </summary>
    /// <param name="stateModifier">Action that changes the state</param>
    /// <returns>The modified clone.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stateModifier"/> is null.</exception>
    public InferenceProjectDocument WithState(Action<JsonObject?> stateModifier)
    {
        // Fail before cloning rather than on invocation afterwards.
        ArgumentNullException.ThrowIfNull(stateModifier);

        var document = (InferenceProjectDocument)Clone();
        stateModifier(document.State);
        return document;
    }

    /// <summary>
    /// Deserializes the state entry at <paramref name="key"/> as <typeparamref name="T"/>,
    /// applies <paramref name="modifier"/>, and writes the serialized result back
    /// under the same key. Mutates this document's <see cref="State"/> in place.
    /// </summary>
    /// <typeparam name="T">Model type stored under <paramref name="key"/>.</typeparam>
    /// <param name="key">Property name inside <see cref="State"/>.</param>
    /// <param name="modifier">Function producing the updated model.</param>
    /// <returns>
    /// False when <see cref="State"/> is null, the key is missing, or the entry
    /// deserializes to null; otherwise true.
    /// </returns>
    public bool TryUpdateModel<T>(string key, Func<T, T> modifier)
    {
        if (State is not { } state)
        {
            return false;
        }

        if (!state.TryGetPropertyValue(key, out var modelNode))
        {
            return false;
        }

        if (modelNode.Deserialize<T>() is not { } model)
        {
            return false;
        }

        state[key] = JsonSerializer.SerializeToNode(modifier(model));
        return true;
    }

    /// <summary>
    /// Applies <paramref name="modifier"/> to the raw JSON node stored at
    /// <paramref name="key"/> and stores the returned node under the same key.
    /// Mutates this document's <see cref="State"/> in place.
    /// </summary>
    /// <param name="key">Property name inside <see cref="State"/>.</param>
    /// <param name="modifier">Function producing the replacement node.</param>
    /// <returns>
    /// False when <see cref="State"/> is null, the key is missing, or the entry
    /// is null; otherwise true.
    /// </returns>
    public bool TryUpdateModel(string key, Func<JsonNode, JsonNode> modifier)
    {
        if (State is not { } state)
        {
            return false;
        }

        if (!state.TryGetPropertyValue(key, out var modelNode) || modelNode is null)
        {
            return false;
        }

        state[key] = modifier(modelNode);
        return true;
    }

    /// <summary>
    /// Returns a clone of this document whose "BatchSize" card has its
    /// "BatchSize" and "BatchCount" values replaced.
    /// </summary>
    /// <param name="batchSize">Value written to the card's "BatchSize" property.</param>
    /// <param name="batchCount">Value written to the card's "BatchCount" property.</param>
    /// <returns>The modified clone; this instance is not changed.</returns>
    /// <exception cref="InvalidOperationException">
    /// <see cref="State"/> is null, or it has no non-null "BatchSize" entry.
    /// </exception>
    public InferenceProjectDocument WithBatchSize(int batchSize, int batchCount)
    {
        if (State is null)
        {
            throw new InvalidOperationException("State is null");
        }

        var document = (InferenceProjectDocument)Clone();

        // State was checked above and Clone copies it, so the clone's State is expected to be set.
        var batchSizeCard =
            document.State!["BatchSize"] ?? throw new InvalidOperationException("BatchSize card is null");

        batchSizeCard["BatchSize"] = batchSize;
        batchSizeCard["BatchCount"] = batchCount;

        return document;
    }

    /// <inheritdoc />
    public object Clone()
    {
        var newObj = (InferenceProjectDocument)MemberwiseClone();

        // MemberwiseClone copies only the reference; State is mutable, so copy it
        // through a serialize/deserialize round trip.
        newObj.State = State is null ? null : JsonSerializer.SerializeToNode(State).Deserialize<JsonObject>();

        return newObj;
    }
}