Browse Source

Add VoltaML package

pull/55/head
Ionite 1 year ago
parent
commit
bd349a74cc
No known key found for this signature in database
  1. 1
      StabilityMatrix.Avalonia/App.axaml.cs
  2. 23
      StabilityMatrix.Core/Models/Packages/A3WebUI.cs
  3. 64
      StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
  4. 2
      StabilityMatrix.Core/Models/Packages/VladAutomatic.cs
  5. 151
      StabilityMatrix.Core/Models/Packages/VoltaML.cs

1
StabilityMatrix.Avalonia/App.axaml.cs

@ -291,6 +291,7 @@ public sealed class App : Application
services.AddSingleton<BasePackage, A3WebUI>();
services.AddSingleton<BasePackage, VladAutomatic>();
services.AddSingleton<BasePackage, ComfyUI>();
services.AddSingleton<BasePackage, VoltaML>();
}
private static IServiceCollection ConfigureServices()

23
StabilityMatrix.Core/Models/Packages/A3WebUI.cs

@ -1,6 +1,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text.RegularExpressions;
using NLog;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.Progress;
@ -12,6 +13,8 @@ namespace StabilityMatrix.Core.Models.Packages;
public class A3WebUI : BaseGitPackage
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
public override string Name => "stable-diffusion-webui";
public override string DisplayName { get; set; } = "Stable Diffusion WebUI";
public override string Author => "AUTOMATIC1111";
@ -110,28 +113,10 @@ public class A3WebUI : BaseGitPackage
public override async Task<string> GetLatestVersion()
{
var release = await GetLatestRelease();
var release = await GetLatestRelease().ConfigureAwait(false);
return release.TagName!;
}
public override async Task<IEnumerable<PackageVersion>> GetAllVersions(bool isReleaseMode = true)
{
if (isReleaseMode)
{
var allReleases = await GetAllReleases();
return allReleases.Where(r => r.Prerelease == false).Select(r => new PackageVersion
{TagName = r.TagName!, ReleaseNotesMarkdown = r.Body});
}
// else, branch mode
var allBranches = await GetAllBranches();
return allBranches.Select(b => new PackageVersion
{
TagName = $"{b.Name}",
ReleaseNotesMarkdown = string.Empty
});
}
public override async Task InstallPackage(IProgress<ProgressReport>? progress = null)
{
await UnzipPackage(progress);

64
StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs

@ -16,9 +16,11 @@ namespace StabilityMatrix.Core.Models.Packages;
/// Author and Name should be the Github username and repository name respectively.
/// </summary>
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
public abstract class BaseGitPackage : BasePackage
{
protected static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
protected readonly IGithubApiCache GithubApi;
protected readonly ISettingsManager SettingsManager;
protected readonly IDownloadService DownloadService;
@ -55,7 +57,9 @@ public abstract class BaseGitPackage : BasePackage
protected async Task<Release> GetLatestRelease(bool includePrerelease = false)
{
var releases = await GithubApi.GetAllReleases(Author, Name);
var releases = await GithubApi
.GetAllReleases(Author, Name)
.ConfigureAwait(false);
return includePrerelease ? releases.First() : releases.First(x => !x.Prerelease);
}
@ -73,6 +77,29 @@ public abstract class BaseGitPackage : BasePackage
{
return GithubApi.GetAllCommits(Author, Name, branch, page, perPage);
}
public override async Task<IEnumerable<PackageVersion>> GetAllVersions(bool isReleaseMode = true)
{
// Release mode
if (isReleaseMode)
{
var allReleases = await GetAllReleases().ConfigureAwait(false);
return allReleases.Where(r => r.Prerelease == false).Select(r =>
new PackageVersion
{
TagName = r.TagName!,
ReleaseNotesMarkdown = r.Body
});
}
// Branch mode
var allBranches = await GetAllBranches().ConfigureAwait(false);
return allBranches.Select(b => new PackageVersion
{
TagName = $"{b.Name}",
ReleaseNotesMarkdown = string.Empty
});
}
/// <summary>
/// Setup the virtual environment for the package.
@ -88,14 +115,16 @@ public abstract class BaseGitPackage : BasePackage
VenvRunner = new PyVenvRunner(venvPath);
if (!VenvRunner.Exists())
{
await VenvRunner.Setup();
await VenvRunner.Setup().ConfigureAwait(false);
}
return VenvRunner;
}
public override async Task<IEnumerable<Release>> GetReleaseTags()
{
var allReleases = await GithubApi.GetAllReleases(Author, Name);
var allReleases = await GithubApi
.GetAllReleases(Author, Name)
.ConfigureAwait(false);
return allReleases;
}
@ -109,7 +138,10 @@ public abstract class BaseGitPackage : BasePackage
Directory.CreateDirectory(DownloadLocation.Replace($"{Name}.zip", ""));
}
await DownloadService.DownloadToFileAsync(downloadUrl, DownloadLocation, progress: progress);
await DownloadService
.DownloadToFileAsync(downloadUrl, DownloadLocation, progress: progress)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(100, message: "Download Complete"));
return version;
@ -117,7 +149,7 @@ public abstract class BaseGitPackage : BasePackage
public override async Task InstallPackage(IProgress<ProgressReport>? progress = null)
{
await UnzipPackage(progress);
await UnzipPackage(progress).ConfigureAwait(false);
progress?.Report(new ProgressReport(1f, $"{DisplayName} installed successfully"));
File.Delete(DownloadLocation);
}
@ -171,13 +203,14 @@ public abstract class BaseGitPackage : BasePackage
{
if (string.IsNullOrWhiteSpace(package.InstalledBranch))
{
var latestVersion = await GetLatestVersion();
var latestVersion = await GetLatestVersion().ConfigureAwait(false);
UpdateAvailable = latestVersion != currentVersion;
return latestVersion != currentVersion;
}
else
{
var allCommits = (await GetAllCommits(package.InstalledBranch))?.ToList();
var allCommits = (await GetAllCommits(package.InstalledBranch)
.ConfigureAwait(false))?.ToList();
if (allCommits == null || !allCommits.Any())
{
Logger.Warn("No commits found for {Package}", package.PackageName);
@ -200,18 +233,19 @@ public abstract class BaseGitPackage : BasePackage
{
if (string.IsNullOrWhiteSpace(installedPackage.InstalledBranch))
{
var releases = await GetAllReleases();
var releases = await GetAllReleases().ConfigureAwait(false);
var latestRelease = releases.First(x => includePrerelease || !x.Prerelease);
await DownloadPackage(latestRelease.TagName, false, progress);
await InstallPackage(progress);
await DownloadPackage(latestRelease.TagName, false, progress).ConfigureAwait(false);
await InstallPackage(progress).ConfigureAwait(false);
return latestRelease.TagName;
}
else
{
var allCommits = await GetAllCommits(installedPackage.InstalledBranch);
var allCommits = await GetAllCommits(
installedPackage.InstalledBranch).ConfigureAwait(false);
var latestCommit = allCommits.First();
await DownloadPackage(latestCommit.Sha, true, progress);
await InstallPackage(progress);
await InstallPackage(progress).ConfigureAwait(false);
return latestCommit.Sha;
}
}
@ -236,13 +270,13 @@ public abstract class BaseGitPackage : BasePackage
Logger.Warn("No process running for {Name}", Name);
return;
}
await process.StandardInput.WriteLineAsync(input);
await process.StandardInput.WriteLineAsync(input).ConfigureAwait(false);
}
public override async Task Shutdown()
{
if (VenvRunner?.Process == null) return;
VenvRunner.Dispose();
await VenvRunner.Process.WaitForExitAsync();
await VenvRunner.Process.WaitForExitAsync().ConfigureAwait(false);
}
}

2
StabilityMatrix.Core/Models/Packages/VladAutomatic.cs

@ -14,6 +14,8 @@ namespace StabilityMatrix.Core.Models.Packages;
public class VladAutomatic : BaseGitPackage
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
public override string Name => "automatic";
public override string DisplayName { get; set; } = "SD.Next Web UI";
public override string Author => "vladmandic";

151
StabilityMatrix.Core/Models/Packages/VoltaML.cs

@ -0,0 +1,151 @@
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
public class VoltaML : BaseGitPackage
{
public override string Name => "voltaML-fast-stable-diffusion";
public override string DisplayName { get; set; } = "VoltaML";
public override string Author => "VoltaML";
public override string Blurb => "Fast Stable Diffusion with support for AITemplate";
public override string LaunchCommand => "main.py";
public override Uri PreviewImageUri => new(
"https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8");
// There are releases but the manager just downloads the latest commit anyways,
// so we'll just limit to commit mode to be more consistent
public override bool ShouldIgnoreReleases => true;
public VoltaML(
IGithubApiCache githubApi,
ISettingsManager settingsManager,
IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper) :
base(githubApi, settingsManager, downloadService, prerequisiteHelper)
{
}
// https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L86
public override Dictionary<SharedFolderType, string> SharedFolders => new()
{
[SharedFolderType.StableDiffusion] = "data/models",
[SharedFolderType.Lora] = "data/lora",
[SharedFolderType.TextualInversion] = "data/textual-inversion",
};
// https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L45
public override List<LaunchOptionDefinition> LaunchOptions => new List<LaunchOptionDefinition>
{
new()
{
Name = "Log Level",
Type = LaunchOptionType.Bool,
DefaultValue = "--log-level INFO",
Options =
{
"--log-level DEBUG",
"--log-level INFO",
"--log-level WARNING",
"--log-level ERROR",
"--log-level CRITICAL"
}
},
new()
{
Name = "Use ngrok to expose the API",
Type = LaunchOptionType.Bool,
Options = {"--ngrok"}
},
new()
{
Name = "Expose the API to the network",
Type = LaunchOptionType.Bool,
Options = {"--host"}
},
new()
{
Name = "Skip virtualenv check",
Type = LaunchOptionType.Bool,
InitialValue = true,
Options = {"--in-container"}
},
new()
{
Name = "Force VoltaML to use a specific type of PyTorch distribution",
Type = LaunchOptionType.Bool,
Options =
{
"--pytorch-type cpu",
"--pytorch-type cuda",
"--pytorch-type rocm",
"--pytorch-type directml",
"--pytorch-type intel",
"--pytorch-type vulkan"
}
},
new()
{
Name = "Run in tandem with the Discord bot",
Type = LaunchOptionType.Bool,
Options = {"--bot"}
},
new()
{
Name = "Enable Cloudflare R2 bucket upload support",
Type = LaunchOptionType.Bool,
Options = {"--enable-r2"}
},
new()
{
Name = "Port",
Type = LaunchOptionType.String,
DefaultValue = "5003",
Options = {"--port"}
},
new()
{
Name = "Only install requirements and exit",
Type = LaunchOptionType.Bool,
Options = {"--install-only"}
},
LaunchOptionDefinition.Extras
};
public override Task<string> GetLatestVersion() => Task.FromResult("main");
public override async Task InstallPackage(IProgress<ProgressReport>? progress = null)
{
await UnzipPackage(progress).ConfigureAwait(false);
// Setup venv
progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true));
using var venvRunner = new PyVenvRunner(Path.Combine(InstallLocation, "venv"));
await venvRunner.Setup().ConfigureAwait(false);
// Install requirements
progress?.Report(new ProgressReport(-1, "Installing Package Requirements", isIndeterminate: true));
await venvRunner
.PipInstall("rich packaging python-dotenv", InstallLocation, OnConsoleOutput)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(1, "Installing Package Requirements", isIndeterminate: false));
}
public override async Task RunPackage(string installedPackagePath, string arguments)
{
await SetupVenv(installedPackagePath).ConfigureAwait(false);
var args = $"\"{Path.Combine(installedPackagePath, LaunchCommand)}\" {arguments}";
VenvRunner?.RunDetached(
args.TrimEnd(),
outputDataReceived: OnConsoleOutput,
onExit: OnExit,
workingDirectory: installedPackagePath);
}
}
Loading…
Cancel
Save