Browse Source

Fix direct-ml not supported for fooocus and mre

pull/240/head
Ionite 1 year ago
parent
commit
93749cf3ff
No known key found for this signature in database
  1. 41
      StabilityMatrix.Core/Models/Packages/Fooocus.cs
  2. 41
      StabilityMatrix.Core/Models/Packages/FooocusMre.cs

41
StabilityMatrix.Core/Models/Packages/Fooocus.cs

@ -160,23 +160,32 @@ public class Fooocus : BaseGitPackage
progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
var extraIndex = torchVersion switch
if (torchVersion == TorchVersion.DirectMl)
{
TorchVersion.Cpu => "cpu",
TorchVersion.Cuda => "cu121",
TorchVersion.Rocm => "rocm5.4.2",
_ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
};
await venvRunner
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.1.0")
.WithTorchVision("==0.16.0")
.WithTorchExtraIndex(extraIndex),
onConsoleOutput
)
.ConfigureAwait(false);
await venvRunner
.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
.ConfigureAwait(false);
}
else
{
var extraIndex = torchVersion switch
{
TorchVersion.Cpu => "cpu",
TorchVersion.Cuda => "cu121",
TorchVersion.Rocm => "rocm5.4.2",
_ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
};
await venvRunner
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.1.0")
.WithTorchVision("==0.16.0")
.WithTorchExtraIndex(extraIndex),
onConsoleOutput
)
.ConfigureAwait(false);
}
var requirements = new FilePath(installLocation, "requirements_versions.txt");
await venvRunner

41
StabilityMatrix.Core/Models/Packages/FooocusMre.cs

@ -118,23 +118,32 @@ public class FooocusMre : BaseGitPackage
progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
var extraIndex = torchVersion switch
if (torchVersion == TorchVersion.DirectMl)
{
TorchVersion.Cpu => "cpu",
TorchVersion.Cuda => "cu118",
TorchVersion.Rocm => "rocm5.4.2",
_ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
};
await venvRunner
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision("==0.15.2")
.WithTorchExtraIndex(extraIndex),
onConsoleOutput
)
.ConfigureAwait(false);
await venvRunner
.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
.ConfigureAwait(false);
}
else
{
var extraIndex = torchVersion switch
{
TorchVersion.Cpu => "cpu",
TorchVersion.Cuda => "cu118",
TorchVersion.Rocm => "rocm5.4.2",
_ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
};
await venvRunner
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision("==0.15.2")
.WithTorchExtraIndex(extraIndex),
onConsoleOutput
)
.ConfigureAwait(false);
}
var requirements = new FilePath(installLocation, "requirements_versions.txt");
await venvRunner

Loading…
Cancel
Save