Browse Source

Add Mps recommend in BasePackage

pull/438/head
ionite34 11 months ago
parent
commit
7d247166ee
No known key found for this signature in database
GPG Key ID: B3404C5F3827849B
  1. 35
      StabilityMatrix.Core/Models/Packages/BasePackage.cs

35
StabilityMatrix.Core/Models/Packages/BasePackage.cs

@ -85,11 +85,20 @@ public abstract class BasePackage
public abstract SharedFolderMethod RecommendedSharedFolderMethod { get; }
public abstract Task SetupModelFolders(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod);
public abstract Task SetupModelFolders(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
);
public abstract Task UpdateModelFolders(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod);
public abstract Task UpdateModelFolders(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
);
public abstract Task RemoveModelFolderLinks(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod);
public abstract Task RemoveModelFolderLinks(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
);
public abstract Task SetupOutputFolderLinks(DirectoryPath installDirectory);
public abstract Task RemoveOutputFolderLinks(DirectoryPath installDirectory);
@ -117,6 +126,11 @@ public abstract class BasePackage
return TorchVersion.DirectMl;
}
if (Compat.IsMacOS && Compat.IsArm && AvailableTorchVersions.Contains(TorchVersion.Mps))
{
return TorchVersion.Mps;
}
return TorchVersion.Cpu;
}
@ -142,7 +156,11 @@ public abstract class BasePackage
public abstract Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders { get; }
public abstract Task<PackageVersionOptions> GetAllVersionOptions();
public abstract Task<IEnumerable<GitCommit>?> GetAllCommits(string branch, int page = 1, int perPage = 10);
public abstract Task<IEnumerable<GitCommit>?> GetAllCommits(
string branch,
int page = 1,
int perPage = 10
);
public abstract Task<DownloadPackageVersionOptions> GetLatestVersion(bool includePrerelease = false);
public abstract string MainBranch { get; }
public event EventHandler<int>? Exited;
@ -153,7 +171,9 @@ public abstract class BasePackage
public void OnStartupComplete(string url) => StartupComplete?.Invoke(this, url);
public virtual PackageVersionType AvailableVersionTypes =>
ShouldIgnoreReleases ? PackageVersionType.Commit : PackageVersionType.GithubRelease | PackageVersionType.Commit;
ShouldIgnoreReleases
? PackageVersionType.Commit
: PackageVersionType.GithubRelease | PackageVersionType.Commit;
protected async Task InstallCudaTorch(
PyVenvRunner venvRunner,
@ -194,6 +214,9 @@ public abstract class BasePackage
{
progress?.Report(new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true));
return venvRunner.PipInstall(new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(), onConsoleOutput);
return venvRunner.PipInstall(
new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(),
onConsoleOutput
);
}
}

Loading…
Cancel
Save