diff --git a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs index e0e913a0..fafe8b49 100644 --- a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.Versioning; @@ -240,7 +241,8 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper public async Task RunNpm( ProcessArgs args, string? workingDirectory = null, - Action? onProcessOutput = null + Action? onProcessOutput = null, + IReadOnlyDictionary? envVars = null ) { var command = args.Prepend([NpmPath]); diff --git a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs index e4da00f3..e42c7370 100644 --- a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs @@ -115,11 +115,12 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper public async Task RunNpm( ProcessArgs args, string? workingDirectory = null, - Action? onProcessOutput = null + Action? onProcessOutput = null, + IReadOnlyDictionary? envVars = null ) { var result = await ProcessRunner - .GetProcessResultAsync(NodeExistsPath, args, workingDirectory) + .GetProcessResultAsync(NodeExistsPath, args, workingDirectory, envVars) .ConfigureAwait(false); result.EnsureSuccessExitCode(); diff --git a/StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs b/StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs index c01f3e88..a04758cd 100644 --- a/StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs +++ b/StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs @@ -157,7 +157,8 @@ public interface IPrerequisiteHelper Task RunNpm( ProcessArgs args, string? workingDirectory = null, - Action? onProcessOutput = null + Action? onProcessOutput = null, + IReadOnlyDictionary? envVars = null ); Task InstallNodeIfNecessary(IProgress? progress = null); } diff --git a/StabilityMatrix.Core/Helper/PrerequisiteHelper.cs b/StabilityMatrix.Core/Helper/PrerequisiteHelper.cs index a32205b3..e4902325 100644 --- a/StabilityMatrix.Core/Helper/PrerequisiteHelper.cs +++ b/StabilityMatrix.Core/Helper/PrerequisiteHelper.cs @@ -112,7 +112,8 @@ public class PrerequisiteHelper : IPrerequisiteHelper public Task RunNpm( ProcessArgs args, string? workingDirectory = null, - Action? onProcessOutput = null + Action? onProcessOutput = null, + IReadOnlyDictionary? envVars = null ) { throw new NotImplementedException(); diff --git a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs index 79a0d0fe..75130ea8 100644 --- a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs +++ b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs @@ -148,7 +148,8 @@ public class A3WebUI( Name = "No Half", Type = LaunchOptionType.Bool, Description = "Do not switch the model to 16-bit floats", - InitialValue = HardwareHelper.PreferRocm() || HardwareHelper.PreferDirectML(), + InitialValue = + HardwareHelper.PreferRocm() || HardwareHelper.PreferDirectML() || Compat.IsMacOS, Options = ["--no-half"] }, new() @@ -171,7 +172,7 @@ public class A3WebUI( new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm }; + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm, TorchVersion.Mps }; public override string MainBranch => "master"; @@ -208,6 +209,7 @@ public class A3WebUI( TorchVersion.Cpu => "cpu", TorchVersion.Cuda => "cu121", TorchVersion.Rocm => "rocm5.6", + TorchVersion.Mps => "nightly/cpu", _ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null) } ) diff --git a/StabilityMatrix.Core/Models/Packages/BasePackage.cs b/StabilityMatrix.Core/Models/Packages/BasePackage.cs index 2972277b..2aeeb57a 100644 --- a/StabilityMatrix.Core/Models/Packages/BasePackage.cs +++ b/StabilityMatrix.Core/Models/Packages/BasePackage.cs @@ -201,10 +201,10 @@ public abstract class BasePackage await venvRunner .PipInstall( new PipInstallArgs() - .WithTorch("==2.0.1") - .WithTorchVision("==0.15.2") - .WithXFormers("==0.0.20") - .WithTorchExtraIndex("cu118"), + .WithTorch("==2.1.2") + .WithTorchVision("==0.16.2") + .WithXFormers("==0.0.23post1") + .WithTorchExtraIndex("cu121"), onConsoleOutput ) .ConfigureAwait(false); @@ -230,7 +230,7 @@ public abstract class BasePackage progress?.Report(new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true)); return venvRunner.PipInstall( - new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(), + new PipInstallArgs().WithTorch("==2.1.2").WithTorchVision(), onConsoleOutput ); } diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs index 6118149b..f71eb598 100644 --- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs +++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs @@ -283,11 +283,15 @@ public class InvokeAI : BaseGitPackage ) { await PrerequisiteHelper.InstallNodeIfNecessary(progress).ConfigureAwait(false); - await PrerequisiteHelper.RunNpm(["i", "pnpm"], installLocation).ConfigureAwait(false); + await PrerequisiteHelper + .RunNpm(["i", "pnpm"], installLocation, envVars: envVars) + .ConfigureAwait(false); if (Compat.IsMacOS || Compat.IsLinux) { - await PrerequisiteHelper.RunNpm(["i", "vite"], installLocation).ConfigureAwait(false); + await PrerequisiteHelper + .RunNpm(["i", "vite"], installLocation, envVars: envVars) + .ConfigureAwait(false); } var pnpmPath = Path.Combine( diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs index 8da91d96..1b161ac9 100644 --- a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs +++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs @@ -142,7 +142,8 @@ public class StableDiffusionUx( Name = "No Half", Type = LaunchOptionType.Bool, Description = "Do not switch the model to 16-bit floats", - InitialValue = HardwareHelper.PreferRocm() || HardwareHelper.PreferDirectML(), + InitialValue = + HardwareHelper.PreferRocm() || HardwareHelper.PreferDirectML() || Compat.IsMacOS, Options = ["--no-half"] }, new() @@ -165,7 +166,7 @@ public class StableDiffusionUx( new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm }; + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm, TorchVersion.Mps }; public override string MainBranch => "master"; @@ -199,6 +200,17 @@ public class StableDiffusionUx( case TorchVersion.Rocm: await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); break; + case TorchVersion.Mps: + await venvRunner + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.1.2") + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu"), + onConsoleOutput + ) + .ConfigureAwait(false); + break; } // Install requirements file diff --git a/StabilityMatrix.Core/Processes/ProcessRunner.cs b/StabilityMatrix.Core/Processes/ProcessRunner.cs index aa2ecc25..2da50556 100644 --- a/StabilityMatrix.Core/Processes/ProcessRunner.cs +++ b/StabilityMatrix.Core/Processes/ProcessRunner.cs @@ -356,7 +356,11 @@ public static class ProcessRunner return StartAnsiProcess(fileName, args, workingDirectory, outputDataReceived, environmentVariables); } - public static async Task RunBashCommand(string command, string workingDirectory = "") + public static async Task RunBashCommand( + string command, + string workingDirectory = "", + IReadOnlyDictionary? environmentVariables = null + ) { // Escape any single quotes in the command var escapedCommand = command.Replace("\"", "\\\""); @@ -372,6 +376,14 @@ public static class ProcessRunner WorkingDirectory = workingDirectory, }; + if (environmentVariables != null) + { + foreach (var (key, value) in environmentVariables) + { + processInfo.EnvironmentVariables[key] = value; + } + } + using var process = new Process(); process.StartInfo = processInfo; @@ -396,12 +408,13 @@ public static class ProcessRunner public static Task RunBashCommand( IEnumerable commands, - string workingDirectory = "" + string workingDirectory = "", + IReadOnlyDictionary? environmentVariables = null ) { // Quote arguments containing spaces var args = string.Join(" ", commands.Select(Quote)); - return RunBashCommand(args, workingDirectory); + return RunBashCommand(args, workingDirectory, environmentVariables); } ///