Browse Source

Merge pull request #350 from ionite34/merge-main-to-dev-ad3b436

pull/324/head
Ionite 1 year ago committed by GitHub
parent
commit
493c3ec5a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      CHANGELOG.md
  2. 19
      StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs
  3. 29
      StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs
  4. 11
      StabilityMatrix.Avalonia/Views/MainWindow.axaml
  5. 15
      StabilityMatrix.Core/Exceptions/ProcessException.cs
  6. 102
      StabilityMatrix.Core/Helper/ArchiveHelper.cs
  7. 12
      StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs
  8. 27
      StabilityMatrix.Core/Helper/PrerequisiteHelper.cs
  9. 74
      StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
  10. 83
      StabilityMatrix.Core/Models/Packages/KohyaSs.cs
  11. 30
      StabilityMatrix.Core/Models/Packages/VladAutomatic.cs
  12. 2
      StabilityMatrix.Core/Models/Packages/VoltaML.cs
  13. 18
      StabilityMatrix.Core/Processes/ProcessResult.cs
  14. 12
      StabilityMatrix.Core/Processes/ProcessRunner.cs
  15. 47
      StabilityMatrix.Core/Python/PyRunner.cs

4
CHANGELOG.md

@ -20,6 +20,10 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
## v2.6.2 ## v2.6.2
### Changed ### Changed
- Backend changes for auto-update schema v3, supporting customizable release channels and faster downloads with zip compression - Backend changes for auto-update schema v3, supporting customizable release channels and faster downloads with zip compression
### Fixed
- Better error reporting including outputs for git subprocess errors during package install / update
- Fixed `'accelerate' is not recognized as an internal or external command` error when starting training in kohya_ss
- Fixed some instances of `ModuleNotFoundError: No module named 'bitsandbytes.cuda_setup.paths'` error when using 8-bit optimizers in kohya_ss
## v2.6.1 ## v2.6.1
### Changed ### Changed

19
StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs

@ -126,16 +126,23 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper
} }
} }
public async Task RunGit( /// <inheritdoc />
string? workingDirectory = null, public Task RunGit(
ProcessArgs args,
Action<ProcessOutput>? onProcessOutput = null, Action<ProcessOutput>? onProcessOutput = null,
params string[] args string? workingDirectory = null
) )
{ {
var command = // Async progress not supported on Unix
args.Length == 0 ? "git" : "git " + string.Join(" ", args.Select(ProcessRunner.Quote)); return RunGit(args, workingDirectory);
}
/// <inheritdoc />
public async Task RunGit(ProcessArgs args, string? workingDirectory = null)
{
var command = args.Prepend("git");
var result = await ProcessRunner.RunBashCommand(command, workingDirectory ?? ""); var result = await ProcessRunner.RunBashCommand(command.ToArray(), workingDirectory ?? "");
if (result.ExitCode != 0) if (result.ExitCode != 0)
{ {
Logger.Error( Logger.Error(

29
StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs

@ -6,6 +6,7 @@ using System.Threading.Tasks;
using Microsoft.Win32; using Microsoft.Win32;
using NLog; using NLog;
using Octokit; using Octokit;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Processes;
@ -64,23 +65,35 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
} }
public async Task RunGit( public async Task RunGit(
string? workingDirectory = null, ProcessArgs args,
Action<ProcessOutput>? onProcessOutput = null, Action<ProcessOutput>? onProcessOutput,
params string[] args string? workingDirectory = null
) )
{ {
var process = ProcessRunner.StartAnsiProcess( var process = ProcessRunner.StartAnsiProcess(
GitExePath, GitExePath,
args, args.ToArray(),
workingDirectory: workingDirectory, workingDirectory,
onProcessOutput,
environmentVariables: new Dictionary<string, string> environmentVariables: new Dictionary<string, string>
{ {
{ "PATH", Compat.GetEnvPathWithExtensions(GitBinPath) } { "PATH", Compat.GetEnvPathWithExtensions(GitBinPath) }
}, }
outputDataReceived: onProcessOutput
); );
await process.WaitForExitAsync().ConfigureAwait(false);
if (process.ExitCode != 0)
{
throw new ProcessException($"Git exited with code {process.ExitCode}");
}
}
public async Task RunGit(ProcessArgs args, string? workingDirectory = null)
{
var result = await ProcessRunner
.GetProcessResultAsync(GitExePath, args, workingDirectory)
.ConfigureAwait(false);
await ProcessRunner.WaitForExitConditionAsync(process); result.EnsureSuccessExitCode();
} }
public async Task<string> GetGitOutput(string? workingDirectory = null, params string[] args) public async Task<string> GetGitOutput(string? workingDirectory = null, params string[] args)

11
StabilityMatrix.Avalonia/Views/MainWindow.axaml

@ -21,6 +21,15 @@
DockProperties.IsDropEnabled="True" DockProperties.IsDropEnabled="True"
x:Class="StabilityMatrix.Avalonia.Views.MainWindow"> x:Class="StabilityMatrix.Avalonia.Views.MainWindow">
<controls:AppWindowBase.Resources>
<SolidColorBrush x:Key="BrushB0" Color="#FFFFFFFF" />
<DrawingImage x:Key="BrandsPatreonSymbolWhite">
<DrawingGroup>
<GeometryDrawing Brush="{DynamicResource BrushB0}" Geometry="F1 M1033.05 324.45C1032.86 186.55 925.46 73.53 799.45 32.75C642.97 -17.89 436.59 -10.55 287.17 59.95C106.07 145.41 49.18 332.61 47.06 519.31C45.32 672.81 60.64 1077.1 288.68 1079.98C458.12 1082.13 483.35 863.8 561.75 758.65C617.53 683.84 689.35 662.71 777.76 640.83C929.71 603.22 1033.27 483.3 1033.05 324.45Z" />
</DrawingGroup>
</DrawingImage>
</controls:AppWindowBase.Resources>
<Grid RowDefinitions="Auto,Auto,*"> <Grid RowDefinitions="Auto,Auto,*">
<Grid Name="TitleBarHost" <Grid Name="TitleBarHost"
ColumnDefinitions="Auto,Auto,*,Auto" ColumnDefinitions="Auto,Auto,*,Auto"
@ -101,7 +110,7 @@
Content="{x:Static lang:Resources.Label_BecomeAPatron}" Content="{x:Static lang:Resources.Label_BecomeAPatron}"
Tapped="PatreonPatreonItem_OnTapped"> Tapped="PatreonPatreonItem_OnTapped">
<ui:NavigationViewItem.IconSource> <ui:NavigationViewItem.IconSource>
<controls:FASymbolIconSource Symbol="fa-brands fa-patreon"/> <ui:ImageIconSource Source="{StaticResource BrandsPatreonSymbolWhite}" />
</ui:NavigationViewItem.IconSource> </ui:NavigationViewItem.IconSource>
</ui:NavigationViewItem> </ui:NavigationViewItem>

15
StabilityMatrix.Core/Exceptions/ProcessException.cs

@ -1,11 +1,22 @@
namespace StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Exceptions;
/// <summary> /// <summary>
/// Exception that is thrown when a process fails. /// Exception that is thrown when a process fails.
/// </summary> /// </summary>
public class ProcessException : Exception public class ProcessException : Exception
{ {
public ProcessException(string message) : base(message) public ProcessResult? ProcessResult { get; }
public ProcessException(string message)
: base(message) { }
public ProcessException(ProcessResult processResult)
: base(
$"Process {processResult.ProcessName} exited with code {processResult.ExitCode}. {{StdOut = {processResult.StandardOutput}, StdErr = {processResult.StandardError}}}"
)
{ {
ProcessResult = processResult;
} }
} }

102
StabilityMatrix.Core/Helper/ArchiveHelper.cs

@ -15,7 +15,6 @@ namespace StabilityMatrix.Core.Helper;
public record struct ArchiveInfo(ulong Size, ulong CompressedSize); public record struct ArchiveInfo(ulong Size, ulong CompressedSize);
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public static partial class ArchiveHelper public static partial class ArchiveHelper
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -59,7 +58,7 @@ public static partial class ArchiveHelper
public static async Task<ArchiveInfo> TestArchive(string archivePath) public static async Task<ArchiveInfo> TestArchive(string archivePath)
{ {
var process = ProcessRunner.StartAnsiProcess(SevenZipPath, new[] {"t", archivePath}); var process = ProcessRunner.StartAnsiProcess(SevenZipPath, new[] { "t", archivePath });
await process.WaitForExitAsync(); await process.WaitForExitAsync();
var output = await process.StandardOutput.ReadToEndAsync(); var output = await process.StandardOutput.ReadToEndAsync();
var matches = Regex7ZOutput().Matches(output); var matches = Regex7ZOutput().Matches(output);
@ -74,11 +73,15 @@ public static partial class ArchiveHelper
var sourceParent = Directory.GetParent(sourceDirectory)?.FullName ?? ""; var sourceParent = Directory.GetParent(sourceDirectory)?.FullName ?? "";
// We must pass in as `directory\` for archive path to be correct // We must pass in as `directory\` for archive path to be correct
var sourceDirName = new DirectoryInfo(sourceDirectory).Name; var sourceDirName = new DirectoryInfo(sourceDirectory).Name;
var process = ProcessRunner.StartAnsiProcess(SevenZipPath, new[]
{ var result = await ProcessRunner
"a", archivePath, sourceDirName + @"\", "-y" .GetProcessResultAsync(
}, workingDirectory: sourceParent); SevenZipPath,
await ProcessRunner.WaitForExitConditionAsync(process); new[] { "a", archivePath, sourceDirName + @"\", "-y" },
workingDirectory: sourceParent
)
.ConfigureAwait(false);
result.EnsureSuccessExitCode();
} }
public static async Task<ArchiveInfo> Extract7Z(string archivePath, string extractDirectory) public static async Task<ArchiveInfo> Extract7Z(string archivePath, string extractDirectory)
@ -113,12 +116,17 @@ public static partial class ArchiveHelper
} }
} }
public static async Task<ArchiveInfo> Extract7Z(string archivePath, string extractDirectory, IProgress<ProgressReport> progress) public static async Task<ArchiveInfo> Extract7Z(
string archivePath,
string extractDirectory,
IProgress<ProgressReport> progress
)
{ {
var outputStore = new StringBuilder(); var outputStore = new StringBuilder();
var onOutput = new Action<string?>(s => var onOutput = new Action<string?>(s =>
{ {
if (s == null) return; if (s == null)
return;
// Parse progress // Parse progress
Logger.Trace($"7z: {s}"); Logger.Trace($"7z: {s}");
@ -128,7 +136,14 @@ public static partial class ArchiveHelper
{ {
var percent = int.Parse(match.Groups[1].Value); var percent = int.Parse(match.Groups[1].Value);
var currentFile = match.Groups[2].Value; var currentFile = match.Groups[2].Value;
progress.Report(new ProgressReport(percent / (float) 100, "Extracting", currentFile, type: ProgressType.Extract)); progress.Report(
new ProgressReport(
percent / (float)100,
"Extracting",
currentFile,
type: ProgressType.Extract
)
);
} }
}); });
progress.Report(new ProgressReport(-1, isIndeterminate: true, type: ProgressType.Extract)); progress.Report(new ProgressReport(-1, isIndeterminate: true, type: ProgressType.Extract));
@ -216,7 +231,11 @@ public static partial class ArchiveHelper
/// <param name="progress"></param> /// <param name="progress"></param>
/// <param name="archivePath"></param> /// <param name="archivePath"></param>
/// <param name="outputDirectory">Output directory, created if does not exist.</param> /// <param name="outputDirectory">Output directory, created if does not exist.</param>
public static async Task Extract(string archivePath, string outputDirectory, IProgress<ProgressReport>? progress = default) public static async Task Extract(
string archivePath,
string outputDirectory,
IProgress<ProgressReport>? progress = default
)
{ {
Directory.CreateDirectory(outputDirectory); Directory.CreateDirectory(outputDirectory);
progress?.Report(new ProgressReport(-1, isIndeterminate: true)); progress?.Report(new ProgressReport(-1, isIndeterminate: true));
@ -229,11 +248,12 @@ public static partial class ArchiveHelper
// If not available, use the size of the archive file // If not available, use the size of the archive file
if (total == 0) if (total == 0)
{ {
total = (ulong) new FileInfo(archivePath).Length; total = (ulong)new FileInfo(archivePath).Length;
} }
// Create an DispatchTimer that monitors the progress of the extraction // Create an DispatchTimer that monitors the progress of the extraction
var progressMonitor = progress switch { var progressMonitor = progress switch
{
null => null, null => null,
_ => new Timer(TimeSpan.FromMilliseconds(36)) _ => new Timer(TimeSpan.FromMilliseconds(36))
}; };
@ -242,34 +262,38 @@ public static partial class ArchiveHelper
{ {
progressMonitor.Elapsed += (_, _) => progressMonitor.Elapsed += (_, _) =>
{ {
if (count == 0) return; if (count == 0)
return;
progress!.Report(new ProgressReport(count, total, message: "Extracting")); progress!.Report(new ProgressReport(count, total, message: "Extracting"));
}; };
} }
await Task.Factory.StartNew(() => await Task.Factory.StartNew(
{ () =>
var extractOptions = new ExtractionOptions
{ {
Overwrite = true, var extractOptions = new ExtractionOptions
ExtractFullPath = true, {
}; Overwrite = true,
using var stream = File.OpenRead(archivePath); ExtractFullPath = true,
using var archive = ReaderFactory.Open(stream); };
using var stream = File.OpenRead(archivePath);
using var archive = ReaderFactory.Open(stream);
// Start the progress reporting timer // Start the progress reporting timer
progressMonitor?.Start(); progressMonitor?.Start();
while (archive.MoveToNextEntry()) while (archive.MoveToNextEntry())
{
var entry = archive.Entry;
if (!entry.IsDirectory)
{ {
count += (ulong) entry.CompressedSize; var entry = archive.Entry;
if (!entry.IsDirectory)
{
count += (ulong)entry.CompressedSize;
}
archive.WriteEntryToDirectory(outputDirectory, extractOptions);
} }
archive.WriteEntryToDirectory(outputDirectory, extractOptions); },
} TaskCreationOptions.LongRunning
}, TaskCreationOptions.LongRunning); );
progress?.Report(new ProgressReport(progress: 1, message: "Done extracting")); progress?.Report(new ProgressReport(progress: 1, message: "Done extracting"));
progressMonitor?.Stop(); progressMonitor?.Stop();
@ -328,14 +352,18 @@ public static partial class ArchiveHelper
{ {
// Not sure why but symlink entries have a key that ends with a space // Not sure why but symlink entries have a key that ends with a space
// and some broken path suffix, so we'll remove everything after the last space // and some broken path suffix, so we'll remove everything after the last space
Logger.Debug($"Checking if output path {outputPath} contains space char: {outputPath.Contains(' ')}"); Logger.Debug(
$"Checking if output path {outputPath} contains space char: {outputPath.Contains(' ')}"
);
if (outputPath.Contains(' ')) if (outputPath.Contains(' '))
{ {
outputPath = outputPath[..outputPath.LastIndexOf(' ')]; outputPath = outputPath[..outputPath.LastIndexOf(' ')];
} }
Logger.Debug($"Extracting symbolic link [{entry.Key.ToRepr()}] " + Logger.Debug(
$"({outputPath.ToRepr()} to {entry.LinkTarget.ToRepr()})"); $"Extracting symbolic link [{entry.Key.ToRepr()}] "
+ $"({outputPath.ToRepr()} to {entry.LinkTarget.ToRepr()})"
);
// Try to write link, if fail, continue copy file // Try to write link, if fail, continue copy file
try try
{ {
@ -346,7 +374,9 @@ public static partial class ArchiveHelper
} }
catch (IOException e) catch (IOException e)
{ {
Logger.Warn($"Could not extract symbolic link, copying file instead: {e.Message}"); Logger.Warn(
$"Could not extract symbolic link, copying file instead: {e.Message}"
);
} }
} }

12
StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs

@ -22,10 +22,16 @@ public interface IPrerequisiteHelper
/// Run embedded git with the given arguments. /// Run embedded git with the given arguments.
/// </summary> /// </summary>
Task RunGit( Task RunGit(
string? workingDirectory = null, ProcessArgs args,
Action<ProcessOutput>? onProcessOutput = null, Action<ProcessOutput>? onProcessOutput,
params string[] args string? workingDirectory = null
); );
/// <summary>
/// Run embedded git with the given arguments.
/// </summary>
Task RunGit(ProcessArgs args, string? workingDirectory = null);
Task<string> GetGitOutput(string? workingDirectory = null, params string[] args); Task<string> GetGitOutput(string? workingDirectory = null, params string[] args);
Task InstallTkinterIfNecessary(IProgress<ProgressReport>? progress = null); Task InstallTkinterIfNecessary(IProgress<ProgressReport>? progress = null);
} }

27
StabilityMatrix.Core/Helper/PrerequisiteHelper.cs

@ -1,8 +1,10 @@
using System.Reflection; using System.Diagnostics;
using System.Reflection;
using System.Runtime.Versioning; using System.Runtime.Versioning;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Win32; using Microsoft.Win32;
using Octokit; using Octokit;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
@ -61,18 +63,31 @@ public class PrerequisiteHelper : IPrerequisiteHelper
} }
public async Task RunGit( public async Task RunGit(
string? workingDirectory = null, ProcessArgs args,
Action<ProcessOutput>? onProcessOutput = null, Action<ProcessOutput>? onProcessOutput,
params string[] args string? workingDirectory = null
) )
{ {
var process = ProcessRunner.StartAnsiProcess( var process = ProcessRunner.StartAnsiProcess(
GitExePath, GitExePath,
args, args.ToArray(),
workingDirectory, workingDirectory,
onProcessOutput onProcessOutput
); );
await ProcessRunner.WaitForExitConditionAsync(process).ConfigureAwait(false); await process.WaitForExitAsync().ConfigureAwait(false);
if (process.ExitCode != 0)
{
throw new ProcessException($"Git exited with code {process.ExitCode}");
}
}
public async Task RunGit(ProcessArgs args, string? workingDirectory = null)
{
var result = await ProcessRunner
.GetProcessResultAsync(GitExePath, args, workingDirectory)
.ConfigureAwait(false);
result.EnsureSuccessExitCode();
} }
public async Task<string> GetGitOutput(string? workingDirectory = null, params string[] args) public async Task<string> GetGitOutput(string? workingDirectory = null, params string[] args)

74
StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs

@ -185,13 +185,14 @@ public abstract class BaseGitPackage : BasePackage
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
null, new[]
null, {
"clone", "clone",
"--branch", "--branch",
versionOptions.VersionTag, versionOptions.VersionTag,
GithubUrl, GithubUrl,
$"\"{installLocation}\"" installLocation
}
) )
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -199,13 +200,14 @@ public abstract class BaseGitPackage : BasePackage
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
null, new[]
null, {
"clone", "clone",
"--branch", "--branch",
versionOptions.BranchName, versionOptions.BranchName,
GithubUrl, GithubUrl,
$"\"{installLocation}\"" installLocation
}
) )
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -213,7 +215,7 @@ public abstract class BaseGitPackage : BasePackage
if (!versionOptions.IsLatest && !string.IsNullOrWhiteSpace(versionOptions.CommitHash)) if (!versionOptions.IsLatest && !string.IsNullOrWhiteSpace(versionOptions.CommitHash))
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit(installLocation, null, "checkout", versionOptions.CommitHash) .RunGit(new[] { "checkout", versionOptions.CommitHash }, installLocation)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -327,12 +329,9 @@ public abstract class BaseGitPackage : BasePackage
.ConfigureAwait(false); .ConfigureAwait(false);
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installedPackage.FullPath!, new[] { "remote", "add", "origin", GithubUrl },
onConsoleOutput, onConsoleOutput,
"remote", installedPackage.FullPath
"add",
"origin",
GithubUrl
) )
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -341,7 +340,7 @@ public abstract class BaseGitPackage : BasePackage
{ {
progress?.Report(new ProgressReport(-1f, "Fetching tags...", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Fetching tags...", isIndeterminate: true));
await PrerequisiteHelper await PrerequisiteHelper
.RunGit(installedPackage.FullPath!, onConsoleOutput, "fetch", "--tags") .RunGit(new[] { "fetch", "--tags" }, onConsoleOutput, installedPackage.FullPath)
.ConfigureAwait(false); .ConfigureAwait(false);
progress?.Report( progress?.Report(
@ -353,11 +352,9 @@ public abstract class BaseGitPackage : BasePackage
); );
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installedPackage.FullPath!, new[] { "checkout", versionOptions.VersionTag, "--force" },
onConsoleOutput, onConsoleOutput,
"checkout", installedPackage.FullPath
versionOptions.VersionTag,
"--force"
) )
.ConfigureAwait(false); .ConfigureAwait(false);
@ -381,7 +378,7 @@ public abstract class BaseGitPackage : BasePackage
// fetch // fetch
progress?.Report(new ProgressReport(-1f, "Fetching data...", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Fetching data...", isIndeterminate: true));
await PrerequisiteHelper await PrerequisiteHelper
.RunGit(installedPackage.FullPath!, onConsoleOutput, "fetch") .RunGit("fetch", onConsoleOutput, installedPackage.FullPath)
.ConfigureAwait(false); .ConfigureAwait(false);
if (versionOptions.IsLatest) if (versionOptions.IsLatest)
@ -396,11 +393,9 @@ public abstract class BaseGitPackage : BasePackage
); );
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installedPackage.FullPath!, new[] { "checkout", versionOptions.BranchName!, "--force" },
onConsoleOutput, onConsoleOutput,
"checkout", installedPackage.FullPath
versionOptions.BranchName,
"--force"
) )
.ConfigureAwait(false); .ConfigureAwait(false);
@ -408,12 +403,15 @@ public abstract class BaseGitPackage : BasePackage
progress?.Report(new ProgressReport(-1f, "Pulling changes...", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Pulling changes...", isIndeterminate: true));
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installedPackage.FullPath!, new[]
{
"pull",
"--autostash",
"origin",
installedPackage.Version.InstalledBranch!
},
onConsoleOutput, onConsoleOutput,
"pull", installedPackage.FullPath!
"--autostash",
"origin",
installedPackage.Version.InstalledBranch
) )
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -429,11 +427,9 @@ public abstract class BaseGitPackage : BasePackage
); );
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installedPackage.FullPath!, new[] { "checkout", versionOptions.CommitHash!, "--force" },
onConsoleOutput, onConsoleOutput,
"checkout", installedPackage.FullPath
versionOptions.CommitHash,
"--force"
) )
.ConfigureAwait(false); .ConfigureAwait(false);
} }

83
StabilityMatrix.Core/Models/Packages/KohyaSs.cs

@ -1,4 +1,5 @@
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using Python.Runtime;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper;
@ -14,13 +15,19 @@ namespace StabilityMatrix.Core.Models.Packages;
[Singleton(typeof(BasePackage))] [Singleton(typeof(BasePackage))]
public class KohyaSs : BaseGitPackage public class KohyaSs : BaseGitPackage
{ {
private readonly IPyRunner pyRunner;
public KohyaSs( public KohyaSs(
IGithubApiCache githubApi, IGithubApiCache githubApi,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IDownloadService downloadService, IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner
) )
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } : base(githubApi, settingsManager, downloadService, prerequisiteHelper)
{
this.pyRunner = pyRunner;
}
public override string Name => "kohya_ss"; public override string Name => "kohya_ss";
public override string DisplayName { get; set; } = "kohya_ss"; public override string DisplayName { get; set; } = "kohya_ss";
@ -147,6 +154,8 @@ public class KohyaSs : BaseGitPackage
// Install // Install
venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput); venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput);
await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false);
await venvRunner.PipInstall("bitsandbytes-windows").ConfigureAwait(false);
} }
else if (Compat.IsLinux) else if (Compat.IsLinux)
{ {
@ -168,28 +177,64 @@ public class KohyaSs : BaseGitPackage
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
// update gui files to point to venv accelerate // update gui files to point to venv accelerate
var filesToUpdate = new[] await pyRunner.RunInThreadWithLock(() =>
{ {
"lora_gui.py", var scope = Py.CreateScope();
"dreambooth_gui.py", scope.Exec(
"textual_inversion_gui.py", """
Path.Combine("library", "wd14_caption_gui.py"), import ast
"finetune_gui.py"
}; class StringReplacer(ast.NodeTransformer):
def __init__(self, old: str, new: str, replace_count: int = -1):
self.old = old
self.new = new
self.replace_count = replace_count
def visit_Constant(self, node: ast.Constant) -> ast.Constant:
if isinstance(node.value, str) and self.old in node.value:
new_value = node.value.replace(self.old, self.new, self.replace_count)
node.value = new_value
return node
def rewrite_module(self, module_text: str) -> str:
tree = ast.parse(module_text)
tree = self.visit(tree)
return ast.unparse(tree)
"""
);
foreach (var file in filesToUpdate)
{
var path = Path.Combine(installedPackagePath, file);
var text = await File.ReadAllTextAsync(path).ConfigureAwait(false);
var replacementAcceleratePath = Compat.IsWindows var replacementAcceleratePath = Compat.IsWindows
? @".\\venv\\scripts\\accelerate" ? @".\venv\scripts\accelerate"
: "./venv/bin/accelerate"; : "./venv/bin/accelerate";
text = text.Replace(
"run_cmd = f'accelerate launch", var replacer = scope.InvokeMethod(
$"run_cmd = f'{replacementAcceleratePath} launch" "StringReplacer",
"accelerate".ToPython(),
$"{replacementAcceleratePath}".ToPython(),
1.ToPython()
); );
await File.WriteAllTextAsync(path, text).ConfigureAwait(false);
} var filesToUpdate = new[]
{
"lora_gui.py",
"dreambooth_gui.py",
"textual_inversion_gui.py",
Path.Combine("library", "wd14_caption_gui.py"),
"finetune_gui.py"
};
foreach (var file in filesToUpdate)
{
var path = Path.Combine(installedPackagePath, file);
var text = File.ReadAllText(path);
if (text.Contains(replacementAcceleratePath.Replace(@"\", @"\\")))
continue;
var result = replacer.InvokeMethod("rewrite_module", text.ToPython());
var resultStr = result.ToString();
File.WriteAllText(path, resultStr);
}
});
void HandleConsoleOutput(ProcessOutput s) void HandleConsoleOutput(ProcessOutput s)
{ {

30
StabilityMatrix.Core/Models/Packages/VladAutomatic.cs

@ -242,29 +242,28 @@ public class VladAutomatic : BaseGitPackage
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installDir.Parent ?? "", new[] { "clone", "https://github.com/vladmandic/automatic", installDir.Name },
null, installDir.Parent?.FullPath ?? ""
"clone",
"https://github.com/vladmandic/automatic",
installDir.Name
) )
.ConfigureAwait(false); .ConfigureAwait(false);
await PrerequisiteHelper await PrerequisiteHelper
.RunGit(installLocation, null, "checkout", downloadOptions.CommitHash) .RunGit(new[] { "checkout", downloadOptions.CommitHash }, installLocation)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
else if (!string.IsNullOrWhiteSpace(downloadOptions.BranchName)) else if (!string.IsNullOrWhiteSpace(downloadOptions.BranchName))
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installDir.Parent ?? "", new[]
null, {
"clone", "clone",
"-b", "-b",
downloadOptions.BranchName, downloadOptions.BranchName,
"https://github.com/vladmandic/automatic", "https://github.com/vladmandic/automatic",
installDir.Name installDir.Name
},
installDir.Parent?.FullPath ?? ""
) )
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -325,10 +324,9 @@ public class VladAutomatic : BaseGitPackage
await PrerequisiteHelper await PrerequisiteHelper
.RunGit( .RunGit(
installedPackage.FullPath, new[] { "checkout", versionOptions.BranchName! },
onConsoleOutput, onConsoleOutput,
"checkout", installedPackage.FullPath
versionOptions.BranchName
) )
.ConfigureAwait(false); .ConfigureAwait(false);

2
StabilityMatrix.Core/Models/Packages/VoltaML.cs

@ -62,7 +62,7 @@ public class VoltaML : BaseGitPackage
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
public override IEnumerable<TorchVersion> AvailableTorchVersions => public override IEnumerable<TorchVersion> AvailableTorchVersions =>
new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Mps }; new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl };
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };

18
StabilityMatrix.Core/Processes/ProcessResult.cs

@ -1,8 +1,24 @@
namespace StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Exceptions;
namespace StabilityMatrix.Core.Processes;
public readonly record struct ProcessResult public readonly record struct ProcessResult
{ {
public required int ExitCode { get; init; } public required int ExitCode { get; init; }
public string? StandardOutput { get; init; } public string? StandardOutput { get; init; }
public string? StandardError { get; init; } public string? StandardError { get; init; }
public string? ProcessName { get; init; }
public TimeSpan Elapsed { get; init; }
public bool IsSuccessExitCode => ExitCode == 0;
public void EnsureSuccessExitCode()
{
if (!IsSuccessExitCode)
{
throw new ProcessException(this);
}
}
} }

12
StabilityMatrix.Core/Processes/ProcessRunner.cs

@ -208,7 +208,9 @@ public static class ProcessRunner
{ {
ExitCode = process.ExitCode, ExitCode = process.ExitCode,
StandardOutput = stdout, StandardOutput = stdout,
StandardError = stderr StandardError = stderr,
ProcessName = process.MachineName,
Elapsed = process.ExitTime - process.StartTime
}; };
} }
@ -425,6 +427,14 @@ public static class ProcessRunner
CancellationToken cancelToken = default CancellationToken cancelToken = default
) )
{ {
if (process is AnsiProcess)
{
throw new ArgumentException(
$"{nameof(WaitForExitConditionAsync)} does not support AnsiProcess, which uses custom async data reading",
nameof(process)
);
}
var stdout = new StringBuilder(); var stdout = new StringBuilder();
var stderr = new StringBuilder(); var stderr = new StringBuilder();
process.OutputDataReceived += (_, args) => stdout.Append(args.Data); process.OutputDataReceived += (_, args) => stdout.Append(args.Data);

47
StabilityMatrix.Core/Python/PyRunner.cs

@ -124,8 +124,11 @@ public class PyRunner : IPyRunner
{ {
throw new FileNotFoundException("get-pip not found", GetPipPath); throw new FileNotFoundException("get-pip not found", GetPipPath);
} }
var p = ProcessRunner.StartAnsiProcess(PythonExePath, "-m get-pip");
await ProcessRunner.WaitForExitConditionAsync(p); var result = await ProcessRunner
.GetProcessResultAsync(PythonExePath, "-m get-pip")
.ConfigureAwait(false);
result.EnsureSuccessExitCode();
} }
/// <summary> /// <summary>
@ -137,8 +140,10 @@ public class PyRunner : IPyRunner
{ {
throw new FileNotFoundException("pip not found", PipExePath); throw new FileNotFoundException("pip not found", PipExePath);
} }
var p = ProcessRunner.StartAnsiProcess(PipExePath, $"install {package}"); var result = await ProcessRunner
await ProcessRunner.WaitForExitConditionAsync(p); .GetProcessResultAsync(PipExePath, $"install {package}")
.ConfigureAwait(false);
result.EnsureSuccessExitCode();
} }
/// <summary> /// <summary>
@ -159,15 +164,16 @@ public class PyRunner : IPyRunner
try try
{ {
return await Task.Run( return await Task.Run(
() => () =>
{
using (Py.GIL())
{ {
return func(); using (Py.GIL())
} {
}, return func();
cancelToken }
); },
cancelToken
)
.ConfigureAwait(false);
} }
finally finally
{ {
@ -193,15 +199,16 @@ public class PyRunner : IPyRunner
try try
{ {
await Task.Run( await Task.Run(
() => () =>
{
using (Py.GIL())
{ {
action(); using (Py.GIL())
} {
}, action();
cancelToken }
); },
cancelToken
)
.ConfigureAwait(false);
} }
finally finally
{ {

Loading…
Cancel
Save