Ionite
1 year ago
11 changed files with 263 additions and 37 deletions
@ -0,0 +1,6 @@
|
||||
using OneOf; |
||||
|
||||
namespace StabilityMatrix.Core.Processes; |
||||
|
||||
[GenerateOneOf] |
||||
public partial class Argument : OneOfBase<string, (string, string)> { } |
@ -0,0 +1,75 @@
|
||||
using System.Diagnostics; |
||||
using OneOf; |
||||
|
||||
namespace StabilityMatrix.Core.Processes; |
||||
|
||||
/// <summary> |
||||
/// Builder for <see cref="ProcessArgs"/>. |
||||
/// </summary> |
||||
public record ProcessArgsBuilder |
||||
{ |
||||
protected ProcessArgsBuilder() { } |
||||
|
||||
public ProcessArgsBuilder(params Argument[] arguments) |
||||
{ |
||||
Arguments = arguments.ToList(); |
||||
} |
||||
|
||||
public List<Argument> Arguments { get; init; } = new(); |
||||
|
||||
private IEnumerable<string> ToStringArgs() |
||||
{ |
||||
foreach (var argument in Arguments) |
||||
{ |
||||
if (argument.IsT0) |
||||
{ |
||||
yield return argument.AsT0; |
||||
} |
||||
else |
||||
{ |
||||
yield return argument.AsT1.Item1; |
||||
yield return argument.AsT1.Item2; |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// <inheritdoc /> |
||||
public override string ToString() |
||||
{ |
||||
return ToProcessArgs().ToString(); |
||||
} |
||||
|
||||
public ProcessArgs ToProcessArgs() |
||||
{ |
||||
return ToStringArgs().ToArray(); |
||||
} |
||||
|
||||
public static implicit operator ProcessArgs(ProcessArgsBuilder builder) => |
||||
builder.ToProcessArgs(); |
||||
} |
||||
|
||||
public static class ProcessArgBuilderExtensions |
||||
{ |
||||
public static T AddArg<T>(this T builder, Argument argument) |
||||
where T : ProcessArgsBuilder |
||||
{ |
||||
return builder with { Arguments = builder.Arguments.Append(argument).ToList() }; |
||||
} |
||||
|
||||
public static T RemoveArgKey<T>(this T builder, string argumentKey) |
||||
where T : ProcessArgsBuilder |
||||
{ |
||||
return builder with |
||||
{ |
||||
Arguments = builder.Arguments |
||||
.Where( |
||||
x => |
||||
x.Match( |
||||
stringArg => stringArg != argumentKey, |
||||
tupleArg => tupleArg.Item1 != argumentKey |
||||
) |
||||
) |
||||
.ToList() |
||||
}; |
||||
} |
||||
} |
@ -0,0 +1,38 @@
|
||||
using Semver; |
||||
using StabilityMatrix.Core.Processes; |
||||
|
||||
namespace StabilityMatrix.Core.Python; |
||||
|
||||
public record PipInstallArgs : ProcessArgsBuilder |
||||
{ |
||||
public PipInstallArgs(params Argument[] arguments) |
||||
: base(arguments) { } |
||||
|
||||
public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}"); |
||||
|
||||
public PipInstallArgs WithTorchDirectML(string version = "") => |
||||
this.AddArg($"torch-directml{version}"); |
||||
|
||||
public PipInstallArgs WithTorchVision(string version = "") => |
||||
this.AddArg($"torchvision{version}"); |
||||
|
||||
public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}"); |
||||
|
||||
public PipInstallArgs WithExtraIndex(string indexUrl) => |
||||
this.AddArg(("--extra-index-url", indexUrl)); |
||||
|
||||
public PipInstallArgs WithTorchExtraIndex(string index) => |
||||
this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}")); |
||||
|
||||
public static PipInstallArgs GetTorch(string version = "") => |
||||
new() { Arguments = { $"torch{version}", "torchvision" } }; |
||||
|
||||
public static PipInstallArgs GetTorchDirectML(string version = "") => |
||||
new() { Arguments = { $"torch-directml{version}" } }; |
||||
|
||||
/// <inheritdoc /> |
||||
public override string ToString() |
||||
{ |
||||
return base.ToString(); |
||||
} |
||||
} |
@ -0,0 +1,61 @@
|
||||
using StabilityMatrix.Core.Processes; |
||||
using StabilityMatrix.Core.Python; |
||||
|
||||
namespace StabilityMatrix.Tests.Core; |
||||
|
||||
[TestClass] |
||||
public class PipInstallArgsTests |
||||
{ |
||||
[TestMethod] |
||||
public void TestGetTorch() |
||||
{ |
||||
// Arrange |
||||
const string version = "==2.1.0"; |
||||
|
||||
// Act |
||||
var args = PipInstallArgs.GetTorch(version).ToProcessArgs().ToString(); |
||||
|
||||
// Assert |
||||
Assert.AreEqual("torch==2.1.0 torchvision", args); |
||||
} |
||||
|
||||
[TestMethod] |
||||
public void TestGetTorchWithExtraIndex() |
||||
{ |
||||
// Arrange |
||||
const string version = ">=2.0.0"; |
||||
const string index = "cu118"; |
||||
|
||||
// Act |
||||
var args = new PipInstallArgs() |
||||
.WithTorch(version) |
||||
.WithTorchVision() |
||||
.WithTorchExtraIndex(index) |
||||
.ToProcessArgs() |
||||
.ToString(); |
||||
|
||||
// Assert |
||||
Assert.AreEqual( |
||||
"torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118", |
||||
args |
||||
); |
||||
} |
||||
|
||||
[TestMethod] |
||||
public void TestGetTorchWithMoreStuff() |
||||
{ |
||||
// Act |
||||
var args = new PipInstallArgs() |
||||
.AddArg("--pre") |
||||
.WithTorch("~=2.0.0") |
||||
.WithTorchVision() |
||||
.WithTorchExtraIndex("nightly/cpu") |
||||
.ToString(); |
||||
|
||||
// Assert |
||||
Assert.AreEqual( |
||||
"--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu", |
||||
args |
||||
); |
||||
} |
||||
} |
Loading…
Reference in new issue