Browse Source

Change static calls to instance for PipInstallArgs

pull/240/head
Ionite 1 year ago
parent
commit
b7a2fb58bd
No known key found for this signature in database
  1. 5
      StabilityMatrix.Core/Models/Packages/BasePackage.cs
  2. 9
      StabilityMatrix.Core/Python/PipInstallArgs.cs
  3. 4
      StabilityMatrix.Tests/Core/PipInstallArgsTests.cs

5
StabilityMatrix.Core/Models/Packages/BasePackage.cs

@ -222,6 +222,9 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true) new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true)
); );
return venvRunner.PipInstall(PipInstallArgs.GetTorch("==2.0.1"), onConsoleOutput); return venvRunner.PipInstall(
new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(),
onConsoleOutput
);
} }
} }

9
StabilityMatrix.Core/Python/PipInstallArgs.cs

@ -1,5 +1,4 @@
using Semver; using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Python; namespace StabilityMatrix.Core.Python;
@ -24,12 +23,6 @@ public record PipInstallArgs : ProcessArgsBuilder
public PipInstallArgs WithTorchExtraIndex(string index) => public PipInstallArgs WithTorchExtraIndex(string index) =>
this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}")); this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}"));
public static PipInstallArgs GetTorch(string version = "") =>
new() { Arguments = { $"torch{version}", "torchvision" } };
public static PipInstallArgs GetTorchDirectML(string version = "") =>
new() { Arguments = { $"torch-directml{version}" } };
/// <inheritdoc /> /// <inheritdoc />
public override string ToString() public override string ToString()
{ {

4
StabilityMatrix.Tests/Core/PipInstallArgsTests.cs

@ -13,10 +13,10 @@ public class PipInstallArgsTests
const string version = "==2.1.0"; const string version = "==2.1.0";
// Act // Act
var args = PipInstallArgs.GetTorch(version).ToProcessArgs().ToString(); var args = new PipInstallArgs().WithTorch(version).ToProcessArgs().ToString();
// Assert // Assert
Assert.AreEqual("torch==2.1.0 torchvision", args); Assert.AreEqual("torch==2.1.0", args);
} }
[TestMethod] [TestMethod]

Loading…
Cancel
Save