diff --git a/CHANGELOG.md b/CHANGELOG.md index 28a6b5be..af474bb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2 ## v2.6.2 ### Changed - Backend changes for auto-update schema v3, supporting customizable release channels and faster downloads with zip compression +### Fixed +- Fixed `'accelerate' is not recognized as an internal or external command` error when starting training in kohya_ss +- Fixed some instances of `ModuleNotFoundError: No module named 'bitsandbytes.cuda_setup.paths'` error when using 8-bit optimizers in kohya_ss ## v2.6.1 ### Changed diff --git a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs index 45a13d3d..99b2a6a3 100644 --- a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs +++ b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs @@ -1,4 +1,5 @@ using System.Text.RegularExpressions; +using Python.Runtime; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; @@ -14,13 +15,19 @@ namespace StabilityMatrix.Core.Models.Packages; [Singleton(typeof(BasePackage))] public class KohyaSs : BaseGitPackage { + private readonly IPyRunner pyRunner; + public KohyaSs( IGithubApiCache githubApi, ISettingsManager settingsManager, IDownloadService downloadService, - IPrerequisiteHelper prerequisiteHelper + IPrerequisiteHelper prerequisiteHelper, + IPyRunner pyRunner ) - : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } + : base(githubApi, settingsManager, downloadService, prerequisiteHelper) + { + this.pyRunner = pyRunner; + } public override string Name => "kohya_ss"; public override string DisplayName { get; set; } = "kohya_ss"; @@ -147,6 +154,8 @@ public class KohyaSs : BaseGitPackage // Install venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput); await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); + + await venvRunner.PipInstall("bitsandbytes-windows").ConfigureAwait(false); } else if (Compat.IsLinux) { @@ -168,28 +177,64 @@ public class KohyaSs : BaseGitPackage await SetupVenv(installedPackagePath).ConfigureAwait(false); // update gui files to point to venv accelerate - var filesToUpdate = new[] + await pyRunner.RunInThreadWithLock(() => { - "lora_gui.py", - "dreambooth_gui.py", - "textual_inversion_gui.py", - Path.Combine("library", "wd14_caption_gui.py"), - "finetune_gui.py" - }; + var scope = Py.CreateScope(); + scope.Exec( + """ + import ast + + class StringReplacer(ast.NodeTransformer): + def __init__(self, old: str, new: str, replace_count: int = -1): + self.old = old + self.new = new + self.replace_count = replace_count + + def visit_Constant(self, node: ast.Constant) -> ast.Constant: + if isinstance(node.value, str) and self.old in node.value: + new_value = node.value.replace(self.old, self.new, self.replace_count) + node.value = new_value + return node + + def rewrite_module(self, module_text: str) -> str: + tree = ast.parse(module_text) + tree = self.visit(tree) + return ast.unparse(tree) + """ + ); - foreach (var file in filesToUpdate) - { - var path = Path.Combine(installedPackagePath, file); - var text = await File.ReadAllTextAsync(path).ConfigureAwait(false); var replacementAcceleratePath = Compat.IsWindows - ? @".\\venv\\scripts\\accelerate" + ? @".\venv\scripts\accelerate" : "./venv/bin/accelerate"; - text = text.Replace( - "run_cmd = f'accelerate launch", - $"run_cmd = f'{replacementAcceleratePath} launch" + + var replacer = scope.InvokeMethod( + "StringReplacer", + "accelerate".ToPython(), + $"{replacementAcceleratePath}".ToPython(), + 1.ToPython() ); - await File.WriteAllTextAsync(path, text).ConfigureAwait(false); - } + + var filesToUpdate = new[] + { + "lora_gui.py", + "dreambooth_gui.py", + "textual_inversion_gui.py", + Path.Combine("library", "wd14_caption_gui.py"), + "finetune_gui.py" + }; + + foreach (var file in filesToUpdate) + { + var path = Path.Combine(installedPackagePath, file); + var text = File.ReadAllText(path); + if (text.Contains(replacementAcceleratePath.Replace(@"\", @"\\"))) + continue; + + var result = replacer.InvokeMethod("rewrite_module", text.ToPython()); + var resultStr = result.ToString(); + File.WriteAllText(path, resultStr); + } + }); void HandleConsoleOutput(ProcessOutput s) { diff --git a/StabilityMatrix.Core/Models/Packages/VoltaML.cs b/StabilityMatrix.Core/Models/Packages/VoltaML.cs index eeef1ce3..933e4fb3 100644 --- a/StabilityMatrix.Core/Models/Packages/VoltaML.cs +++ b/StabilityMatrix.Core/Models/Packages/VoltaML.cs @@ -62,7 +62,7 @@ public class VoltaML : BaseGitPackage public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Mps }; + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl }; public override IEnumerable AvailableSharedFolderMethods => new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };