diff --git a/StabilityMatrix.Avalonia/Assets.cs b/StabilityMatrix.Avalonia/Assets.cs index db08322e..c422abdb 100644 --- a/StabilityMatrix.Avalonia/Assets.cs +++ b/StabilityMatrix.Avalonia/Assets.cs @@ -136,6 +136,10 @@ internal static class Assets ) ); + [SupportedOSPlatform("windows")] + public static AvaloniaResource TkinterZip => + new("avares://StabilityMatrix.Avalonia/Assets/win-x64/tkinter_3_10_7.zip"); + public static IReadOnlyList DefaultCompletionTags { get; } = new[] { diff --git a/StabilityMatrix.Avalonia/Assets/win-x64/tkinter_3_10_7.zip b/StabilityMatrix.Avalonia/Assets/win-x64/tkinter_3_10_7.zip new file mode 100644 index 00000000..78a589bc Binary files /dev/null and b/StabilityMatrix.Avalonia/Assets/win-x64/tkinter_3_10_7.zip differ diff --git a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs index eee43c2c..04cb8ffd 100644 --- a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs @@ -233,6 +233,13 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper throw new NotImplementedException(); } + [UnsupportedOSPlatform("Linux")] + [UnsupportedOSPlatform("macOS")] + public Task InstallTkinterIfNecessary(IProgress? progress = null) + { + throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("Linux")] [UnsupportedOSPlatform("macOS")] public Task InstallVcRedistIfNecessary(IProgress? progress = null) diff --git a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs index ad77d03a..fbf8203c 100644 --- a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs @@ -43,6 +43,8 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper private string PortableGitInstallDir => Path.Combine(HomeDir, "PortableGit"); private string PortableGitDownloadPath => Path.Combine(HomeDir, "PortableGit.7z.exe"); private string GitExePath => Path.Combine(PortableGitInstallDir, "bin", "git.exe"); + private string TkinterZipPath => Path.Combine(AssetsDir, "tkinter.zip"); + private string TkinterExtractPath => PythonDir; public string GitBinPath => Path.Combine(PortableGitInstallDir, "bin"); public bool IsPythonInstalled => File.Exists(PythonDllPath); @@ -223,6 +225,9 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper pythonPthContent = pythonPthContent.Replace("#import site", "import site"); await File.WriteAllTextAsync(pythonPthPath, pythonPthContent); + // Install TKinter + await InstallTkinterIfNecessary(progress); + progress?.Report(new ProgressReport(1f, "Python install complete")); } finally @@ -235,6 +240,17 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper } } + [SupportedOSPlatform("windows")] + public async Task InstallTkinterIfNecessary(IProgress? progress = null) + { + if (!File.Exists(TkinterZipPath)) + { + await Assets.TkinterZip.ExtractTo(TkinterZipPath); + } + + await ArchiveHelper.Extract(TkinterZipPath, TkinterExtractPath, progress); + } + public async Task InstallGitIfNecessary(IProgress? progress = null) { if (File.Exists(GitExePath)) diff --git a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs new file mode 100644 index 00000000..20a651e6 --- /dev/null +++ b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs @@ -0,0 +1,162 @@ +using System.Text.RegularExpressions; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Models.Packages; + +[Singleton(typeof(BasePackage))] +public class KohyaSs : BaseGitPackage +{ + public KohyaSs( + IGithubApiCache githubApi, + ISettingsManager settingsManager, + IDownloadService downloadService, + IPrerequisiteHelper prerequisiteHelper + ) + : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } + + public override string Name => "kohya_ss"; + public override string DisplayName { get; set; } = "Kohya's GUI"; + public override string Author => "bmaltais"; + public override string Blurb => + "A Windows-focused Gradio GUI for Kohya's Stable Diffusion trainers"; + public override string LicenseType => "Apache-2.0"; + public override string LicenseUrl => + "https://github.com/bmaltais/kohya_ss/blob/master/LICENSE.md"; + public override string LaunchCommand => "kohya_gui.py"; + + public override Uri PreviewImageUri => + new( + "https://camo.githubusercontent.com/2170d2204816f428eec57ff87218f06344e0b4d91966343a6c5f0a76df91ec75/68747470733a2f2f696d672e796f75747562652e636f6d2f76692f6b35696d713031757655592f302e6a7067" + ); + public override string OutputFolderName => string.Empty; + + public override bool IsCompatible => HardwareHelper.HasNvidiaGpu(); + + public override TorchVersion GetRecommendedTorchVersion() => TorchVersion.Cuda; + + public override bool OfferInOneClickInstaller => false; + + public override async Task InstallPackage( + string installLocation, + TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, + IProgress? progress = null, + Action? onConsoleOutput = null + ) + { + if (Compat.IsWindows) + { + progress?.Report( + new ProgressReport(-1f, "Installing prerequisites...", isIndeterminate: true) + ); + await PrerequisiteHelper.InstallTkinterIfNecessary(progress).ConfigureAwait(false); + } + + progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); + // Setup venv + await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); + venvRunner.WorkingDirectory = installLocation; + await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false); + + var setupSmPath = Path.Combine(installLocation, "setup", "setup_sm.py"); + var setupText = """ + import setup_windows + import setup_common + + setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=False) + setup_windows.sync_bits_and_bytes_files() + setup_common.configure_accelerate(run_accelerate=False) + """; + await File.WriteAllTextAsync(setupSmPath, setupText).ConfigureAwait(false); + + // Install + venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput); + await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); + } + + public override async Task RunPackage( + string installedPackagePath, + string command, + string arguments, + Action? onConsoleOutput + ) + { + await SetupVenv(installedPackagePath).ConfigureAwait(false); + + var process = ProcessRunner.StartProcess( + Path.Combine(installedPackagePath, "venv", "Scripts", "accelerate.exe"), + "env", + installedPackagePath, + s => onConsoleOutput?.Invoke(new ProcessOutput { Text = s }) + ); + + await process.WaitForExitAsync().ConfigureAwait(false); + + void HandleConsoleOutput(ProcessOutput s) + { + onConsoleOutput?.Invoke(s); + + if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase)) + return; + + var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); + var match = regex.Match(s.Text); + if (!match.Success) + return; + + WebUrl = match.Value; + OnStartupComplete(WebUrl); + } + + var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; + + VenvRunner.EnvironmentVariables = GetEnvVars(); + VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); + } + + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; + public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.Cuda }; + public override List LaunchOptions => + new() { LaunchOptionDefinition.Extras }; + public override Dictionary>? SharedFolders { get; } + public override Dictionary< + SharedOutputType, + IReadOnlyList + >? SharedOutputFolders { get; } + + public override async Task GetLatestVersion() + { + var release = await GetLatestRelease().ConfigureAwait(false); + return release.TagName!; + } + + private Dictionary GetEnvVars() + { + // Set additional required environment variables + var env = new Dictionary(); + if (SettingsManager.Settings.EnvironmentVariables is not null) + { + env.Update(SettingsManager.Settings.EnvironmentVariables); + } + + var tkPath = Path.Combine( + SettingsManager.LibraryDir, + "Assets", + "Python310", + "tcl", + "tcl8.6" + ); + env["TCL_LIBRARY"] = tkPath; + env["TK_LIBRARY"] = tkPath; + + return env; + } +}