Ionite
1 year ago
11 changed files with 263 additions and 37 deletions
@ -0,0 +1,6 @@ |
|||||||
|
using OneOf; |
||||||
|
|
||||||
|
namespace StabilityMatrix.Core.Processes; |
||||||
|
|
||||||
|
[GenerateOneOf] |
||||||
|
public partial class Argument : OneOfBase<string, (string, string)> { } |
@ -0,0 +1,75 @@ |
|||||||
|
using System.Diagnostics; |
||||||
|
using OneOf; |
||||||
|
|
||||||
|
namespace StabilityMatrix.Core.Processes; |
||||||
|
|
||||||
|
/// <summary> |
||||||
|
/// Builder for <see cref="ProcessArgs"/>. |
||||||
|
/// </summary> |
||||||
|
public record ProcessArgsBuilder |
||||||
|
{ |
||||||
|
protected ProcessArgsBuilder() { } |
||||||
|
|
||||||
|
public ProcessArgsBuilder(params Argument[] arguments) |
||||||
|
{ |
||||||
|
Arguments = arguments.ToList(); |
||||||
|
} |
||||||
|
|
||||||
|
public List<Argument> Arguments { get; init; } = new(); |
||||||
|
|
||||||
|
private IEnumerable<string> ToStringArgs() |
||||||
|
{ |
||||||
|
foreach (var argument in Arguments) |
||||||
|
{ |
||||||
|
if (argument.IsT0) |
||||||
|
{ |
||||||
|
yield return argument.AsT0; |
||||||
|
} |
||||||
|
else |
||||||
|
{ |
||||||
|
yield return argument.AsT1.Item1; |
||||||
|
yield return argument.AsT1.Item2; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// <inheritdoc /> |
||||||
|
public override string ToString() |
||||||
|
{ |
||||||
|
return ToProcessArgs().ToString(); |
||||||
|
} |
||||||
|
|
||||||
|
public ProcessArgs ToProcessArgs() |
||||||
|
{ |
||||||
|
return ToStringArgs().ToArray(); |
||||||
|
} |
||||||
|
|
||||||
|
public static implicit operator ProcessArgs(ProcessArgsBuilder builder) => |
||||||
|
builder.ToProcessArgs(); |
||||||
|
} |
||||||
|
|
||||||
|
public static class ProcessArgBuilderExtensions |
||||||
|
{ |
||||||
|
public static T AddArg<T>(this T builder, Argument argument) |
||||||
|
where T : ProcessArgsBuilder |
||||||
|
{ |
||||||
|
return builder with { Arguments = builder.Arguments.Append(argument).ToList() }; |
||||||
|
} |
||||||
|
|
||||||
|
public static T RemoveArgKey<T>(this T builder, string argumentKey) |
||||||
|
where T : ProcessArgsBuilder |
||||||
|
{ |
||||||
|
return builder with |
||||||
|
{ |
||||||
|
Arguments = builder.Arguments |
||||||
|
.Where( |
||||||
|
x => |
||||||
|
x.Match( |
||||||
|
stringArg => stringArg != argumentKey, |
||||||
|
tupleArg => tupleArg.Item1 != argumentKey |
||||||
|
) |
||||||
|
) |
||||||
|
.ToList() |
||||||
|
}; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,38 @@ |
|||||||
|
using Semver; |
||||||
|
using StabilityMatrix.Core.Processes; |
||||||
|
|
||||||
|
namespace StabilityMatrix.Core.Python; |
||||||
|
|
||||||
|
public record PipInstallArgs : ProcessArgsBuilder |
||||||
|
{ |
||||||
|
public PipInstallArgs(params Argument[] arguments) |
||||||
|
: base(arguments) { } |
||||||
|
|
||||||
|
public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}"); |
||||||
|
|
||||||
|
public PipInstallArgs WithTorchDirectML(string version = "") => |
||||||
|
this.AddArg($"torch-directml{version}"); |
||||||
|
|
||||||
|
public PipInstallArgs WithTorchVision(string version = "") => |
||||||
|
this.AddArg($"torchvision{version}"); |
||||||
|
|
||||||
|
public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}"); |
||||||
|
|
||||||
|
public PipInstallArgs WithExtraIndex(string indexUrl) => |
||||||
|
this.AddArg(("--extra-index-url", indexUrl)); |
||||||
|
|
||||||
|
public PipInstallArgs WithTorchExtraIndex(string index) => |
||||||
|
this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}")); |
||||||
|
|
||||||
|
public static PipInstallArgs GetTorch(string version = "") => |
||||||
|
new() { Arguments = { $"torch{version}", "torchvision" } }; |
||||||
|
|
||||||
|
public static PipInstallArgs GetTorchDirectML(string version = "") => |
||||||
|
new() { Arguments = { $"torch-directml{version}" } }; |
||||||
|
|
||||||
|
/// <inheritdoc /> |
||||||
|
public override string ToString() |
||||||
|
{ |
||||||
|
return base.ToString(); |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,61 @@ |
|||||||
|
using StabilityMatrix.Core.Processes; |
||||||
|
using StabilityMatrix.Core.Python; |
||||||
|
|
||||||
|
namespace StabilityMatrix.Tests.Core; |
||||||
|
|
||||||
|
[TestClass] |
||||||
|
public class PipInstallArgsTests |
||||||
|
{ |
||||||
|
[TestMethod] |
||||||
|
public void TestGetTorch() |
||||||
|
{ |
||||||
|
// Arrange |
||||||
|
const string version = "==2.1.0"; |
||||||
|
|
||||||
|
// Act |
||||||
|
var args = PipInstallArgs.GetTorch(version).ToProcessArgs().ToString(); |
||||||
|
|
||||||
|
// Assert |
||||||
|
Assert.AreEqual("torch==2.1.0 torchvision", args); |
||||||
|
} |
||||||
|
|
||||||
|
[TestMethod] |
||||||
|
public void TestGetTorchWithExtraIndex() |
||||||
|
{ |
||||||
|
// Arrange |
||||||
|
const string version = ">=2.0.0"; |
||||||
|
const string index = "cu118"; |
||||||
|
|
||||||
|
// Act |
||||||
|
var args = new PipInstallArgs() |
||||||
|
.WithTorch(version) |
||||||
|
.WithTorchVision() |
||||||
|
.WithTorchExtraIndex(index) |
||||||
|
.ToProcessArgs() |
||||||
|
.ToString(); |
||||||
|
|
||||||
|
// Assert |
||||||
|
Assert.AreEqual( |
||||||
|
"torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118", |
||||||
|
args |
||||||
|
); |
||||||
|
} |
||||||
|
|
||||||
|
[TestMethod] |
||||||
|
public void TestGetTorchWithMoreStuff() |
||||||
|
{ |
||||||
|
// Act |
||||||
|
var args = new PipInstallArgs() |
||||||
|
.AddArg("--pre") |
||||||
|
.WithTorch("~=2.0.0") |
||||||
|
.WithTorchVision() |
||||||
|
.WithTorchExtraIndex("nightly/cpu") |
||||||
|
.ToString(); |
||||||
|
|
||||||
|
// Assert |
||||||
|
Assert.AreEqual( |
||||||
|
"--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu", |
||||||
|
args |
||||||
|
); |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue