From 88635aede30a96e09e63cb099eebf37989b582ea Mon Sep 17 00:00:00 2001 From: Ionite Date: Mon, 23 Oct 2023 20:01:04 -0400 Subject: [PATCH] Add fluent command arguments and fix pip installs --- .../Models/Packages/A3WebUI.cs | 8 +- .../Models/Packages/BasePackage.cs | 13 +++- .../Models/Packages/ComfyUI.cs | 32 ++++++-- .../Models/Packages/InvokeAI.cs | 8 +- .../Models/Packages/StableDiffusionUx.cs | 8 +- StabilityMatrix.Core/Processes/Argument.cs | 6 ++ StabilityMatrix.Core/Processes/ProcessArgs.cs | 28 ++++++- .../Processes/ProcessArgsBuilder.cs | 75 +++++++++++++++++++ StabilityMatrix.Core/Python/PipInstallArgs.cs | 38 ++++++++++ StabilityMatrix.Core/Python/PyVenvRunner.cs | 23 +----- .../Core/PipInstallArgsTests.cs | 61 +++++++++++++++ 11 files changed, 263 insertions(+), 37 deletions(-) create mode 100644 StabilityMatrix.Core/Processes/Argument.cs create mode 100644 StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs create mode 100644 StabilityMatrix.Core/Python/PipInstallArgs.cs create mode 100644 StabilityMatrix.Tests/Core/PipInstallArgsTests.cs diff --git a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs index 665d68c6..44cce164 100644 --- a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs +++ b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs @@ -281,7 +281,13 @@ public class A3WebUI : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.1.1"), + onConsoleOutput + ) .ConfigureAwait(false); } } diff --git a/StabilityMatrix.Core/Models/Packages/BasePackage.cs b/StabilityMatrix.Core/Models/Packages/BasePackage.cs index e1bc0ae1..a7b0b194 100644 --- a/StabilityMatrix.Core/Models/Packages/BasePackage.cs +++ b/StabilityMatrix.Core/Models/Packages/BasePackage.cs @@ -189,9 +189,14 @@ public abstract class BasePackage ); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithXFormers("==0.0.20") + .WithTorchExtraIndex("cu118"), + onConsoleOutput + ) .ConfigureAwait(false); - await venvRunner.PipInstall("xformers==0.0.20", onConsoleOutput).ConfigureAwait(false); } protected Task InstallDirectMlTorch( @@ -204,7 +209,7 @@ public abstract class BasePackage new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true) ); - return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput); + return venvRunner.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput); } protected Task InstallCpuTorch( @@ -217,6 +222,6 @@ public abstract class BasePackage new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true) ); - return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, onConsoleOutput); + return venvRunner.PipInstall(PipInstallArgs.GetTorch("==2.0.1"), onConsoleOutput); } } diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 2fe22438..bcb39eb3 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -179,15 +179,20 @@ public class ComfyUI : BaseGitPackage break; case TorchVersion.Cuda: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsCuda121, onConsoleOutput) - .ConfigureAwait(false); - await venvRunner - .PipInstall("xformers==0.0.22.post4 --upgrade") + .PipInstall( + new PipInstallArgs() + .WithTorch("~=2.1.0") + .WithTorchVision() + .WithXFormers("==0.0.22.post4") + .AddArg("--upgrade") + .WithTorchExtraIndex("cu121"), + onConsoleOutput + ) .ConfigureAwait(false); break; case TorchVersion.DirectMl: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput) + .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput) .ConfigureAwait(false); break; case TorchVersion.Rocm: @@ -195,7 +200,14 @@ public class ComfyUI : BaseGitPackage break; case TorchVersion.Mps: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsNightlyCpu, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .AddArg("--pre") + .WithTorch() + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu"), + onConsoleOutput + ) .ConfigureAwait(false); break; default: @@ -465,7 +477,13 @@ public class ComfyUI : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm56, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.6"), + onConsoleOutput + ) .ConfigureAwait(false); } diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs index ebb4b722..34830a65 100644 --- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs +++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs @@ -182,7 +182,13 @@ public class InvokeAI : BaseGitPackage // For AMD, Install ROCm version case TorchVersion.Rocm: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithExtraIndex("rocm5.4.2"), + onConsoleOutput + ) .ConfigureAwait(false); Logger.Info("Starting InvokeAI install (ROCm)..."); pipCommandArgs = diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs index 5eb3461f..ec8e0052 100644 --- a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs +++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs @@ -263,7 +263,13 @@ public class StableDiffusionUx : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.1.1"), + onConsoleOutput + ) .ConfigureAwait(false); } } diff --git a/StabilityMatrix.Core/Processes/Argument.cs b/StabilityMatrix.Core/Processes/Argument.cs new file mode 100644 index 00000000..44f2c846 --- /dev/null +++ b/StabilityMatrix.Core/Processes/Argument.cs @@ -0,0 +1,6 @@ +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +[GenerateOneOf] +public partial class Argument : OneOfBase { } diff --git a/StabilityMatrix.Core/Processes/ProcessArgs.cs b/StabilityMatrix.Core/Processes/ProcessArgs.cs index 568a7db2..477b5a54 100644 --- a/StabilityMatrix.Core/Processes/ProcessArgs.cs +++ b/StabilityMatrix.Core/Processes/ProcessArgs.cs @@ -9,10 +9,10 @@ namespace StabilityMatrix.Core.Processes; /// Implicitly converts between string and string[], /// with no parsing if the input and output types are the same. /// -public partial class ProcessArgs : OneOfBase +public partial class ProcessArgs : OneOfBase, IEnumerable { /// - private ProcessArgs(OneOf input) + public ProcessArgs(OneOf input) : base(input) { } /// @@ -21,12 +21,36 @@ public partial class ProcessArgs : OneOfBase /// public bool Contains(string arg) => Match(str => str.Contains(arg), arr => arr.Any(Contains)); + public ProcessArgs Concat(ProcessArgs other) => + Match( + str => new ProcessArgs(string.Join(' ', str, other.ToString())), + arr => new ProcessArgs(arr.Concat(other.ToArray()).ToArray()) + ); + + public ProcessArgs Prepend(ProcessArgs other) => + Match( + str => new ProcessArgs(string.Join(' ', other.ToString(), str)), + arr => new ProcessArgs(other.ToArray().Concat(arr).ToArray()) + ); + + /// + public IEnumerator GetEnumerator() + { + return ToArray().AsEnumerable().GetEnumerator(); + } + /// public override string ToString() { return Match(str => str, arr => string.Join(' ', arr.Select(ProcessRunner.Quote))); } + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + public string[] ToArray() => Match( str => ArgumentsRegex().Matches(str).Select(x => x.Value.Trim('"')).ToArray(), diff --git a/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs new file mode 100644 index 00000000..89649564 --- /dev/null +++ b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs @@ -0,0 +1,75 @@ +using System.Diagnostics; +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +/// +/// Builder for . +/// +public record ProcessArgsBuilder +{ + protected ProcessArgsBuilder() { } + + public ProcessArgsBuilder(params Argument[] arguments) + { + Arguments = arguments.ToList(); + } + + public List Arguments { get; init; } = new(); + + private IEnumerable ToStringArgs() + { + foreach (var argument in Arguments) + { + if (argument.IsT0) + { + yield return argument.AsT0; + } + else + { + yield return argument.AsT1.Item1; + yield return argument.AsT1.Item2; + } + } + } + + /// + public override string ToString() + { + return ToProcessArgs().ToString(); + } + + public ProcessArgs ToProcessArgs() + { + return ToStringArgs().ToArray(); + } + + public static implicit operator ProcessArgs(ProcessArgsBuilder builder) => + builder.ToProcessArgs(); +} + +public static class ProcessArgBuilderExtensions +{ + public static T AddArg(this T builder, Argument argument) + where T : ProcessArgsBuilder + { + return builder with { Arguments = builder.Arguments.Append(argument).ToList() }; + } + + public static T RemoveArgKey(this T builder, string argumentKey) + where T : ProcessArgsBuilder + { + return builder with + { + Arguments = builder.Arguments + .Where( + x => + x.Match( + stringArg => stringArg != argumentKey, + tupleArg => tupleArg.Item1 != argumentKey + ) + ) + .ToList() + }; + } +} diff --git a/StabilityMatrix.Core/Python/PipInstallArgs.cs b/StabilityMatrix.Core/Python/PipInstallArgs.cs new file mode 100644 index 00000000..84620484 --- /dev/null +++ b/StabilityMatrix.Core/Python/PipInstallArgs.cs @@ -0,0 +1,38 @@ +using Semver; +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Core.Python; + +public record PipInstallArgs : ProcessArgsBuilder +{ + public PipInstallArgs(params Argument[] arguments) + : base(arguments) { } + + public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}"); + + public PipInstallArgs WithTorchDirectML(string version = "") => + this.AddArg($"torch-directml{version}"); + + public PipInstallArgs WithTorchVision(string version = "") => + this.AddArg($"torchvision{version}"); + + public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}"); + + public PipInstallArgs WithExtraIndex(string indexUrl) => + this.AddArg(("--extra-index-url", indexUrl)); + + public PipInstallArgs WithTorchExtraIndex(string index) => + this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}")); + + public static PipInstallArgs GetTorch(string version = "") => + new() { Arguments = { $"torch{version}", "torchvision" } }; + + public static PipInstallArgs GetTorchDirectML(string version = "") => + new() { Arguments = { $"torch-directml{version}" } }; + + /// + public override string ToString() + { + return base.ToString(); + } +} diff --git a/StabilityMatrix.Core/Python/PyVenvRunner.cs b/StabilityMatrix.Core/Python/PyVenvRunner.cs index 0e7c7419..9b288ad5 100644 --- a/StabilityMatrix.Core/Python/PyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/PyVenvRunner.cs @@ -19,25 +19,6 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - private const string TorchPipInstallArgs = "torch==2.0.1 torchvision"; - - public const string TorchPipInstallArgsCuda = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/cu118"; - public const string TorchPipInstallArgsCuda121 = - "torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121"; - public const string TorchPipInstallArgsCpu = TorchPipInstallArgs; - public const string TorchPipInstallArgsDirectML = "torch-directml"; - - public const string TorchPipInstallArgsRocm511 = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.1.1"; - public const string TorchPipInstallArgsRocm542 = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; - public const string TorchPipInstallArgsRocm56 = - $"{TorchPipInstallArgs} --index-url https://download.pytorch.org/whl/rocm5.6"; - - public const string TorchPipInstallArgsNightlyCpu = - "--pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu"; - /// /// Relative path to the site-packages folder from the venv root. /// This is platform specific. @@ -216,7 +197,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable /// Run a pip install command. Waits for the process to exit. /// workingDirectory defaults to RootPath. /// - public async Task PipInstall(string args, Action? outputDataReceived = null) + public async Task PipInstall(ProcessArgs args, Action? outputDataReceived = null) { if (!File.Exists(PipPath)) { @@ -236,7 +217,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable }); SetPyvenvCfg(PyRunner.PythonDir); - RunDetached($"-m pip install {args}", outputAction); + RunDetached(args.Prepend("-m pip install"), outputAction); await Process.WaitForExitAsync().ConfigureAwait(false); // Check return code diff --git a/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs new file mode 100644 index 00000000..5ffd3e29 --- /dev/null +++ b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs @@ -0,0 +1,61 @@ +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class PipInstallArgsTests +{ + [TestMethod] + public void TestGetTorch() + { + // Arrange + const string version = "==2.1.0"; + + // Act + var args = PipInstallArgs.GetTorch(version).ToProcessArgs().ToString(); + + // Assert + Assert.AreEqual("torch==2.1.0 torchvision", args); + } + + [TestMethod] + public void TestGetTorchWithExtraIndex() + { + // Arrange + const string version = ">=2.0.0"; + const string index = "cu118"; + + // Act + var args = new PipInstallArgs() + .WithTorch(version) + .WithTorchVision() + .WithTorchExtraIndex(index) + .ToProcessArgs() + .ToString(); + + // Assert + Assert.AreEqual( + "torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118", + args + ); + } + + [TestMethod] + public void TestGetTorchWithMoreStuff() + { + // Act + var args = new PipInstallArgs() + .AddArg("--pre") + .WithTorch("~=2.0.0") + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu") + .ToString(); + + // Assert + Assert.AreEqual( + "--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu", + args + ); + } +}