using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Core.Models;
/// <summary>
/// Model file union that may be remote or local.
/// </summary>
/// <remarks>
/// Instances are normally created through <see cref="FromLocal"/>, <see cref="FromRemote"/> or
/// <see cref="FromDownloadable"/>, which set <see cref="Type"/> together with the matching payload property.
/// </remarks>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public record HybridModelFile
{
    /// <summary>
    /// Singleton instance of <see cref="HybridModelFile"/> that represents use of a default model.
    /// </summary>
    public static HybridModelFile Default { get; } = FromRemote("@default");

    /// <summary>
    /// Singleton instance of <see cref="HybridModelFile"/> that represents no model.
    /// </summary>
    public static HybridModelFile None { get; } = FromRemote("@none");

    /// <summary>
    /// Name of the remote model, set when created via <see cref="FromRemote"/>.
    /// </summary>
    public string? RemoteName { get; init; }

    /// <summary>
    /// Local model file, set when created via <see cref="FromLocal"/>.
    /// </summary>
    public LocalModelFile? Local { get; init; }

    /// <summary>
    /// Downloadable model information.
    /// </summary>
    public RemoteResource? DownloadableResource { get; init; }

    /// <summary>
    /// Discriminator that selects which payload <see cref="RelativePath"/> reads from.
    /// </summary>
    public HybridModelType Type { get; init; }

    /// <summary>
    /// Whether <see cref="RemoteName"/> is set.
    /// </summary>
    [MemberNotNullWhen(true, nameof(RemoteName))]
    [JsonIgnore]
    public bool IsRemote => RemoteName != null;

    /// <summary>
    /// Whether <see cref="DownloadableResource"/> is set.
    /// </summary>
    [MemberNotNullWhen(true, nameof(DownloadableResource))]
    public bool IsDownloadable => DownloadableResource != null;

    /// <summary>
    /// Path (or name) identifying the model, taken from the payload selected by <see cref="Type"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="Type"/> is <see cref="HybridModelType.None"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><see cref="Type"/> is not a known value.</exception>
    [JsonIgnore]
    public string RelativePath =>
        Type switch
        {
            HybridModelType.Local => Local!.RelativePathFromSharedFolder,
            HybridModelType.Remote => RemoteName!,
            HybridModelType.Downloadable => DownloadableResource!.Value.FileName,
            HybridModelType.None
                => throw new InvalidOperationException(
                    $"{nameof(RelativePath)} is not available when {nameof(Type)} is {nameof(HybridModelType.None)}."
                ),
            _ => throw new ArgumentOutOfRangeException(nameof(Type), Type, "Unknown model type.")
        };

    /// <summary>
    /// File name portion (with extension) of <see cref="RelativePath"/>.
    /// </summary>
    [JsonIgnore]
    public string FileName => Path.GetFileName(RelativePath);

    /// <summary>
    /// Short, human-readable name for display.
    /// </summary>
    /// <remarks>
    /// Returns fixed labels for <see cref="None"/>, <see cref="Default"/> and the ControlNet reference-only
    /// model; otherwise the file name without extension. For a few generic file names the parent directory
    /// name is appended in parentheses so that entries can be told apart.
    /// </remarks>
    [JsonIgnore]
    public string ShortDisplayName
    {
        get
        {
            if (IsNone)
            {
                return "None";
            }

            if (IsDefault)
            {
                return "Default";
            }

            if (ReferenceEquals(this, RemoteModels.ControlNetReferenceOnlyModel))
            {
                return "Reference Only";
            }

            var fileName = Path.GetFileNameWithoutExtension(RelativePath);

            var isGenericName =
                fileName.Equals("diffusion_pytorch_model", StringComparison.OrdinalIgnoreCase)
                || fileName.Equals("pytorch_model", StringComparison.OrdinalIgnoreCase)
                || fileName.Equals("ip_adapter", StringComparison.OrdinalIgnoreCase);

            if (!isGenericName)
            {
                return fileName;
            }

            // show a friendlier name when models have the same name like ip_adapter or diffusion_pytorch_model
            var directoryName = Path.GetDirectoryName(RelativePath);

            // GetDirectoryName yields null for root paths and an empty string when the path has no
            // directory part; in both cases there is nothing useful to append.
            if (string.IsNullOrEmpty(directoryName))
            {
                return fileName;
            }

            var lastIndex = directoryName.LastIndexOf(Path.DirectorySeparatorChar);
            var parentDirectoryName = lastIndex < 0 ? directoryName : directoryName.Substring(lastIndex + 1);

            // Avoid rendering "name ()" when the directory is just a separator.
            return parentDirectoryName.Length == 0 ? fileName : $"{fileName} ({parentDirectoryName})";
        }
    }

    /// <summary>
    /// Creates an instance of type <see cref="HybridModelType.Local"/> wrapping <paramref name="local"/>.
    /// </summary>
    /// <param name="local">The local model file.</param>
    public static HybridModelFile FromLocal(LocalModelFile local)
    {
        return new HybridModelFile { Local = local, Type = HybridModelType.Local };
    }

    /// <summary>
    /// Creates an instance of type <see cref="HybridModelType.Remote"/> with the given name.
    /// </summary>
    /// <param name="remoteName">The remote model name.</param>
    public static HybridModelFile FromRemote(string remoteName)
    {
        return new HybridModelFile { RemoteName = remoteName, Type = HybridModelType.Remote };
    }

    /// <summary>
    /// Creates an instance of type <see cref="HybridModelType.Downloadable"/> wrapping <paramref name="resource"/>.
    /// </summary>
    /// <param name="resource">The downloadable resource.</param>
    public static HybridModelFile FromDownloadable(RemoteResource resource)
    {
        return new HybridModelFile { DownloadableResource = resource, Type = HybridModelType.Downloadable };
    }

    /// <summary>
    /// Builds an identifier string from <see cref="RelativePath"/>, <see cref="IsNone"/> and <see cref="IsDefault"/>.
    /// </summary>
    public string GetId()
    {
        return $"{RelativePath};{IsNone};{IsDefault}";
    }

    /// <summary>
    /// Comparer that treats models with the same <see cref="RelativePath"/> as equal, except when both
    /// have the same <see cref="Type"/> but different local config paths.
    /// </summary>
    private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer<HybridModelFile>
    {
        /// <inheritdoc />
        public bool Equals(HybridModelFile? x, HybridModelFile? y)
        {
            if (ReferenceEquals(x, y))
                return true;
            if (ReferenceEquals(x, null))
                return false;
            if (ReferenceEquals(y, null))
                return false;
            if (x.GetType() != y.GetType())
                return false;

            // Resolves to the static object.Equals(object, object): ordinal string comparison.
            if (!Equals(x.RelativePath, y.RelativePath))
                return false;

            // This equality affects replacements of remote over local models
            // We want local and remote models to be considered equal if they have the same relative path
            // But 2 local models with the same path but different config paths should be considered different
            return !(x.Type == y.Type && x.Local?.ConfigFullPath != y.Local?.ConfigFullPath);
        }

        /// <inheritdoc />
        public int GetHashCode(HybridModelFile obj)
        {
            // NOTE(review): IsNone/IsDefault take part in the hash but not in Equals, so a non-singleton
            // instance whose RelativePath matches a singleton's compares equal yet hashes differently.
            // Left unchanged because either fix alters hash-based collection behavior — confirm intent.
            return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.RelativePath);
        }
    }

    /// <summary>
    /// Whether this instance is the default model.
    /// </summary>
    public bool IsDefault => ReferenceEquals(this, Default);

    /// <summary>
    /// Whether this instance is no model.
    /// </summary>
    public bool IsNone => ReferenceEquals(this, None);

    /// <summary>
    /// Shared comparer instance; see <see cref="RemoteNameLocalEqualityComparer"/> for the equality rules.
    /// </summary>
    public static IEqualityComparer<HybridModelFile> Comparer { get; } =
        new RemoteNameLocalEqualityComparer();
}