Browse Source

Add fluent command arguments and fix pip installs

pull/240/head
Ionite 1 year ago
parent
commit
88635aede3
No known key found for this signature in database
  1. 8
      StabilityMatrix.Core/Models/Packages/A3WebUI.cs
  2. 13
      StabilityMatrix.Core/Models/Packages/BasePackage.cs
  3. 32
      StabilityMatrix.Core/Models/Packages/ComfyUI.cs
  4. 8
      StabilityMatrix.Core/Models/Packages/InvokeAI.cs
  5. 8
      StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs
  6. 6
      StabilityMatrix.Core/Processes/Argument.cs
  7. 28
      StabilityMatrix.Core/Processes/ProcessArgs.cs
  8. 75
      StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs
  9. 38
      StabilityMatrix.Core/Python/PipInstallArgs.cs
  10. 23
      StabilityMatrix.Core/Python/PyVenvRunner.cs
  11. 61
      StabilityMatrix.Tests/Core/PipInstallArgsTests.cs

8
StabilityMatrix.Core/Models/Packages/A3WebUI.cs

@ -281,7 +281,13 @@ public class A3WebUI : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithTorchExtraIndex("rocm5.1.1"),
onConsoleOutput
)
.ConfigureAwait(false);
}
}

13
StabilityMatrix.Core/Models/Packages/BasePackage.cs

@ -189,9 +189,14 @@ public abstract class BasePackage
);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithXFormers("==0.0.20")
.WithTorchExtraIndex("cu118"),
onConsoleOutput
)
.ConfigureAwait(false);
await venvRunner.PipInstall("xformers==0.0.20", onConsoleOutput).ConfigureAwait(false);
}
protected Task InstallDirectMlTorch(
@ -204,7 +209,7 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true)
);
return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput);
return venvRunner.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput);
}
protected Task InstallCpuTorch(
@ -217,6 +222,6 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true)
);
return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, onConsoleOutput);
return venvRunner.PipInstall(PipInstallArgs.GetTorch("==2.0.1"), onConsoleOutput);
}
}

32
StabilityMatrix.Core/Models/Packages/ComfyUI.cs

@ -179,15 +179,20 @@ public class ComfyUI : BaseGitPackage
break;
case TorchVersion.Cuda:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda121, onConsoleOutput)
.ConfigureAwait(false);
await venvRunner
.PipInstall("xformers==0.0.22.post4 --upgrade")
.PipInstall(
new PipInstallArgs()
.WithTorch("~=2.1.0")
.WithTorchVision()
.WithXFormers("==0.0.22.post4")
.AddArg("--upgrade")
.WithTorchExtraIndex("cu121"),
onConsoleOutput
)
.ConfigureAwait(false);
break;
case TorchVersion.DirectMl:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput)
.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
.ConfigureAwait(false);
break;
case TorchVersion.Rocm:
@ -195,7 +200,14 @@ public class ComfyUI : BaseGitPackage
break;
case TorchVersion.Mps:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsNightlyCpu, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.AddArg("--pre")
.WithTorch()
.WithTorchVision()
.WithTorchExtraIndex("nightly/cpu"),
onConsoleOutput
)
.ConfigureAwait(false);
break;
default:
@ -465,7 +477,13 @@ public class ComfyUI : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm56, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithTorchExtraIndex("rocm5.6"),
onConsoleOutput
)
.ConfigureAwait(false);
}

8
StabilityMatrix.Core/Models/Packages/InvokeAI.cs

@ -182,7 +182,13 @@ public class InvokeAI : BaseGitPackage
// For AMD, Install ROCm version
case TorchVersion.Rocm:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithExtraIndex("rocm5.4.2"),
onConsoleOutput
)
.ConfigureAwait(false);
Logger.Info("Starting InvokeAI install (ROCm)...");
pipCommandArgs =

8
StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs

@ -263,7 +263,13 @@ public class StableDiffusionUx : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithTorchExtraIndex("rocm5.1.1"),
onConsoleOutput
)
.ConfigureAwait(false);
}
}

6
StabilityMatrix.Core/Processes/Argument.cs

@ -0,0 +1,6 @@
using OneOf;
namespace StabilityMatrix.Core.Processes;
[GenerateOneOf]
public partial class Argument : OneOfBase<string, (string, string)> { }

28
StabilityMatrix.Core/Processes/ProcessArgs.cs

@ -9,10 +9,10 @@ namespace StabilityMatrix.Core.Processes;
/// Implicitly converts between string and string[],
/// with no parsing if the input and output types are the same.
/// </summary>
public partial class ProcessArgs : OneOfBase<string, string[]>
public partial class ProcessArgs : OneOfBase<string, string[]>, IEnumerable<string>
{
/// <inheritdoc />
private ProcessArgs(OneOf<string, string[]> input)
public ProcessArgs(OneOf<string, string[]> input)
: base(input) { }
/// <summary>
@ -21,12 +21,36 @@ public partial class ProcessArgs : OneOfBase<string, string[]>
/// </summary>
public bool Contains(string arg) => Match(str => str.Contains(arg), arr => arr.Any(Contains));
public ProcessArgs Concat(ProcessArgs other) =>
Match(
str => new ProcessArgs(string.Join(' ', str, other.ToString())),
arr => new ProcessArgs(arr.Concat(other.ToArray()).ToArray())
);
public ProcessArgs Prepend(ProcessArgs other) =>
Match(
str => new ProcessArgs(string.Join(' ', other.ToString(), str)),
arr => new ProcessArgs(other.ToArray().Concat(arr).ToArray())
);
/// <inheritdoc />
public IEnumerator<string> GetEnumerator()
{
return ToArray().AsEnumerable().GetEnumerator();
}
/// <inheritdoc />
public override string ToString()
{
return Match(str => str, arr => string.Join(' ', arr.Select(ProcessRunner.Quote)));
}
/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public string[] ToArray() =>
Match(
str => ArgumentsRegex().Matches(str).Select(x => x.Value.Trim('"')).ToArray(),

75
StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs

@ -0,0 +1,75 @@
using System.Diagnostics;
using OneOf;
namespace StabilityMatrix.Core.Processes;
/// <summary>
/// Builder for <see cref="ProcessArgs"/>.
/// </summary>
public record ProcessArgsBuilder
{
protected ProcessArgsBuilder() { }
public ProcessArgsBuilder(params Argument[] arguments)
{
Arguments = arguments.ToList();
}
public List<Argument> Arguments { get; init; } = new();
private IEnumerable<string> ToStringArgs()
{
foreach (var argument in Arguments)
{
if (argument.IsT0)
{
yield return argument.AsT0;
}
else
{
yield return argument.AsT1.Item1;
yield return argument.AsT1.Item2;
}
}
}
/// <inheritdoc />
public override string ToString()
{
return ToProcessArgs().ToString();
}
public ProcessArgs ToProcessArgs()
{
return ToStringArgs().ToArray();
}
public static implicit operator ProcessArgs(ProcessArgsBuilder builder) =>
builder.ToProcessArgs();
}
public static class ProcessArgBuilderExtensions
{
public static T AddArg<T>(this T builder, Argument argument)
where T : ProcessArgsBuilder
{
return builder with { Arguments = builder.Arguments.Append(argument).ToList() };
}
public static T RemoveArgKey<T>(this T builder, string argumentKey)
where T : ProcessArgsBuilder
{
return builder with
{
Arguments = builder.Arguments
.Where(
x =>
x.Match(
stringArg => stringArg != argumentKey,
tupleArg => tupleArg.Item1 != argumentKey
)
)
.ToList()
};
}
}

38
StabilityMatrix.Core/Python/PipInstallArgs.cs

@ -0,0 +1,38 @@
using Semver;
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Python;
public record PipInstallArgs : ProcessArgsBuilder
{
public PipInstallArgs(params Argument[] arguments)
: base(arguments) { }
public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}");
public PipInstallArgs WithTorchDirectML(string version = "") =>
this.AddArg($"torch-directml{version}");
public PipInstallArgs WithTorchVision(string version = "") =>
this.AddArg($"torchvision{version}");
public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}");
public PipInstallArgs WithExtraIndex(string indexUrl) =>
this.AddArg(("--extra-index-url", indexUrl));
public PipInstallArgs WithTorchExtraIndex(string index) =>
this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}"));
public static PipInstallArgs GetTorch(string version = "") =>
new() { Arguments = { $"torch{version}", "torchvision" } };
public static PipInstallArgs GetTorchDirectML(string version = "") =>
new() { Arguments = { $"torch-directml{version}" } };
/// <inheritdoc />
public override string ToString()
{
return base.ToString();
}
}

23
StabilityMatrix.Core/Python/PyVenvRunner.cs

@ -19,25 +19,6 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private const string TorchPipInstallArgs = "torch==2.0.1 torchvision";
public const string TorchPipInstallArgsCuda =
$"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/cu118";
public const string TorchPipInstallArgsCuda121 =
"torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121";
public const string TorchPipInstallArgsCpu = TorchPipInstallArgs;
public const string TorchPipInstallArgsDirectML = "torch-directml";
public const string TorchPipInstallArgsRocm511 =
$"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.1.1";
public const string TorchPipInstallArgsRocm542 =
$"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.4.2";
public const string TorchPipInstallArgsRocm56 =
$"{TorchPipInstallArgs} --index-url https://download.pytorch.org/whl/rocm5.6";
public const string TorchPipInstallArgsNightlyCpu =
"--pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu";
/// <summary>
/// Relative path to the site-packages folder from the venv root.
/// This is platform specific.
@ -216,7 +197,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
/// Run a pip install command. Waits for the process to exit.
/// workingDirectory defaults to RootPath.
/// </summary>
public async Task PipInstall(string args, Action<ProcessOutput>? outputDataReceived = null)
public async Task PipInstall(ProcessArgs args, Action<ProcessOutput>? outputDataReceived = null)
{
if (!File.Exists(PipPath))
{
@ -236,7 +217,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
});
SetPyvenvCfg(PyRunner.PythonDir);
RunDetached($"-m pip install {args}", outputAction);
RunDetached(args.Prepend("-m pip install"), outputAction);
await Process.WaitForExitAsync().ConfigureAwait(false);
// Check return code

61
StabilityMatrix.Tests/Core/PipInstallArgsTests.cs

@ -0,0 +1,61 @@
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
namespace StabilityMatrix.Tests.Core;
[TestClass]
public class PipInstallArgsTests
{
[TestMethod]
public void TestGetTorch()
{
// Arrange
const string version = "==2.1.0";
// Act
var args = PipInstallArgs.GetTorch(version).ToProcessArgs().ToString();
// Assert
Assert.AreEqual("torch==2.1.0 torchvision", args);
}
[TestMethod]
public void TestGetTorchWithExtraIndex()
{
// Arrange
const string version = ">=2.0.0";
const string index = "cu118";
// Act
var args = new PipInstallArgs()
.WithTorch(version)
.WithTorchVision()
.WithTorchExtraIndex(index)
.ToProcessArgs()
.ToString();
// Assert
Assert.AreEqual(
"torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118",
args
);
}
[TestMethod]
public void TestGetTorchWithMoreStuff()
{
// Act
var args = new PipInstallArgs()
.AddArg("--pre")
.WithTorch("~=2.0.0")
.WithTorchVision()
.WithTorchExtraIndex("nightly/cpu")
.ToString();
// Assert
Assert.AreEqual(
"--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu",
args
);
}
}
Loading…
Cancel
Save