Browse Source

Fixes a couple kohya_ss errors introduced in the latest version

pull/269/head
JT 1 year ago
parent
commit
0ba3b82308
  1. 3
      CHANGELOG.md
  2. 83
      StabilityMatrix.Core/Models/Packages/KohyaSs.cs
  3. 2
      StabilityMatrix.Core/Models/Packages/VoltaML.cs

3
CHANGELOG.md

@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
## v2.6.2
### Changed
- Backend changes for auto-update schema v3, supporting customizable release channels and faster downloads with zip compression
### Fixed
- Fixed `'accelerate' is not recognized as an internal or external command` error when starting training in kohya_ss
- Fixed some instances of `ModuleNotFoundError: No module named 'bitsandbytes.cuda_setup.paths'` error when using 8-bit optimizers in kohya_ss
## v2.6.1
### Changed

83
StabilityMatrix.Core/Models/Packages/KohyaSs.cs

@ -1,4 +1,5 @@
using System.Text.RegularExpressions;
using Python.Runtime;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
@ -14,13 +15,19 @@ namespace StabilityMatrix.Core.Models.Packages;
[Singleton(typeof(BasePackage))]
public class KohyaSs : BaseGitPackage
{
private readonly IPyRunner pyRunner;
public KohyaSs(
IGithubApiCache githubApi,
ISettingsManager settingsManager,
IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper
IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner
)
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
: base(githubApi, settingsManager, downloadService, prerequisiteHelper)
{
this.pyRunner = pyRunner;
}
public override string Name => "kohya_ss";
public override string DisplayName { get; set; } = "kohya_ss";
@ -147,6 +154,8 @@ public class KohyaSs : BaseGitPackage
// Install
venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput);
await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false);
await venvRunner.PipInstall("bitsandbytes-windows").ConfigureAwait(false);
}
else if (Compat.IsLinux)
{
@ -168,28 +177,64 @@ public class KohyaSs : BaseGitPackage
await SetupVenv(installedPackagePath).ConfigureAwait(false);
// update gui files to point to venv accelerate
var filesToUpdate = new[]
await pyRunner.RunInThreadWithLock(() =>
{
"lora_gui.py",
"dreambooth_gui.py",
"textual_inversion_gui.py",
Path.Combine("library", "wd14_caption_gui.py"),
"finetune_gui.py"
};
var scope = Py.CreateScope();
scope.Exec(
"""
import ast
class StringReplacer(ast.NodeTransformer):
def __init__(self, old: str, new: str, replace_count: int = -1):
self.old = old
self.new = new
self.replace_count = replace_count
def visit_Constant(self, node: ast.Constant) -> ast.Constant:
if isinstance(node.value, str) and self.old in node.value:
new_value = node.value.replace(self.old, self.new, self.replace_count)
node.value = new_value
return node
def rewrite_module(self, module_text: str) -> str:
tree = ast.parse(module_text)
tree = self.visit(tree)
return ast.unparse(tree)
"""
);
foreach (var file in filesToUpdate)
{
var path = Path.Combine(installedPackagePath, file);
var text = await File.ReadAllTextAsync(path).ConfigureAwait(false);
var replacementAcceleratePath = Compat.IsWindows
? @".\\venv\\scripts\\accelerate"
? @".\venv\scripts\accelerate"
: "./venv/bin/accelerate";
text = text.Replace(
"run_cmd = f'accelerate launch",
$"run_cmd = f'{replacementAcceleratePath} launch"
var replacer = scope.InvokeMethod(
"StringReplacer",
"accelerate".ToPython(),
$"{replacementAcceleratePath}".ToPython(),
1.ToPython()
);
await File.WriteAllTextAsync(path, text).ConfigureAwait(false);
}
var filesToUpdate = new[]
{
"lora_gui.py",
"dreambooth_gui.py",
"textual_inversion_gui.py",
Path.Combine("library", "wd14_caption_gui.py"),
"finetune_gui.py"
};
foreach (var file in filesToUpdate)
{
var path = Path.Combine(installedPackagePath, file);
var text = File.ReadAllText(path);
if (text.Contains(replacementAcceleratePath.Replace(@"\", @"\\")))
continue;
var result = replacer.InvokeMethod("rewrite_module", text.ToPython());
var resultStr = result.ToString();
File.WriteAllText(path, resultStr);
}
});
void HandleConsoleOutput(ProcessOutput s)
{

2
StabilityMatrix.Core/Models/Packages/VoltaML.cs

@ -62,7 +62,7 @@ public class VoltaML : BaseGitPackage
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
public override IEnumerable<TorchVersion> AvailableTorchVersions =>
new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Mps };
new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl };
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };

Loading…
Cancel
Save