|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
using System.Text.RegularExpressions; |
|
|
|
|
using Python.Runtime; |
|
|
|
|
using StabilityMatrix.Core.Attributes; |
|
|
|
|
using StabilityMatrix.Core.Extensions; |
|
|
|
|
using StabilityMatrix.Core.Helper; |
|
|
|
@ -14,13 +15,19 @@ namespace StabilityMatrix.Core.Models.Packages;
|
|
|
|
|
[Singleton(typeof(BasePackage))] |
|
|
|
|
public class KohyaSs : BaseGitPackage |
|
|
|
|
{ |
|
|
|
|
private readonly IPyRunner pyRunner; |
|
|
|
|
|
|
|
|
|
public KohyaSs( |
|
|
|
|
IGithubApiCache githubApi, |
|
|
|
|
ISettingsManager settingsManager, |
|
|
|
|
IDownloadService downloadService, |
|
|
|
|
IPrerequisiteHelper prerequisiteHelper |
|
|
|
|
IPrerequisiteHelper prerequisiteHelper, |
|
|
|
|
IPyRunner pyRunner |
|
|
|
|
) |
|
|
|
|
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } |
|
|
|
|
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) |
|
|
|
|
{ |
|
|
|
|
this.pyRunner = pyRunner; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public override string Name => "kohya_ss"; |
|
|
|
|
public override string DisplayName { get; set; } = "kohya_ss"; |
|
|
|
@ -147,6 +154,8 @@ public class KohyaSs : BaseGitPackage
|
|
|
|
|
// Install |
|
|
|
|
venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput); |
|
|
|
|
await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); |
|
|
|
|
|
|
|
|
|
await venvRunner.PipInstall("bitsandbytes-windows").ConfigureAwait(false); |
|
|
|
|
} |
|
|
|
|
else if (Compat.IsLinux) |
|
|
|
|
{ |
|
|
|
@ -168,6 +177,43 @@ public class KohyaSs : BaseGitPackage
|
|
|
|
|
await SetupVenv(installedPackagePath).ConfigureAwait(false); |
|
|
|
|
|
|
|
|
|
// update gui files to point to venv accelerate |
|
|
|
|
await pyRunner.RunInThreadWithLock(() => |
|
|
|
|
{ |
|
|
|
|
var scope = Py.CreateScope(); |
|
|
|
|
scope.Exec( |
|
|
|
|
"""
|
|
|
|
|
import ast |
|
|
|
|
|
|
|
|
|
class StringReplacer(ast.NodeTransformer): |
|
|
|
|
def __init__(self, old: str, new: str, replace_count: int = -1): |
|
|
|
|
self.old = old |
|
|
|
|
self.new = new |
|
|
|
|
self.replace_count = replace_count |
|
|
|
|
|
|
|
|
|
def visit_Constant(self, node: ast.Constant) -> ast.Constant: |
|
|
|
|
if isinstance(node.value, str) and self.old in node.value: |
|
|
|
|
new_value = node.value.replace(self.old, self.new, self.replace_count) |
|
|
|
|
node.value = new_value |
|
|
|
|
return node |
|
|
|
|
|
|
|
|
|
def rewrite_module(self, module_text: str) -> str: |
|
|
|
|
tree = ast.parse(module_text) |
|
|
|
|
tree = self.visit(tree) |
|
|
|
|
return ast.unparse(tree) |
|
|
|
|
"""
|
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
var replacementAcceleratePath = Compat.IsWindows |
|
|
|
|
? @".\venv\scripts\accelerate" |
|
|
|
|
: "./venv/bin/accelerate"; |
|
|
|
|
|
|
|
|
|
var replacer = scope.InvokeMethod( |
|
|
|
|
"StringReplacer", |
|
|
|
|
"accelerate".ToPython(), |
|
|
|
|
$"{replacementAcceleratePath}".ToPython(), |
|
|
|
|
1.ToPython() |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
var filesToUpdate = new[] |
|
|
|
|
{ |
|
|
|
|
"lora_gui.py", |
|
|
|
@ -180,16 +226,15 @@ public class KohyaSs : BaseGitPackage
|
|
|
|
|
foreach (var file in filesToUpdate) |
|
|
|
|
{ |
|
|
|
|
var path = Path.Combine(installedPackagePath, file); |
|
|
|
|
var text = await File.ReadAllTextAsync(path).ConfigureAwait(false); |
|
|
|
|
var replacementAcceleratePath = Compat.IsWindows |
|
|
|
|
? @".\\venv\\scripts\\accelerate" |
|
|
|
|
: "./venv/bin/accelerate"; |
|
|
|
|
text = text.Replace( |
|
|
|
|
"run_cmd = f'accelerate launch", |
|
|
|
|
$"run_cmd = f'{replacementAcceleratePath} launch" |
|
|
|
|
); |
|
|
|
|
await File.WriteAllTextAsync(path, text).ConfigureAwait(false); |
|
|
|
|
var text = File.ReadAllText(path); |
|
|
|
|
if (text.Contains(replacementAcceleratePath.Replace(@"\", @"\\"))) |
|
|
|
|
continue; |
|
|
|
|
|
|
|
|
|
var result = replacer.InvokeMethod("rewrite_module", text.ToPython()); |
|
|
|
|
var resultStr = result.ToString(); |
|
|
|
|
File.WriteAllText(path, resultStr); |
|
|
|
|
} |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
void HandleConsoleOutput(ProcessOutput s) |
|
|
|
|
{ |
|
|
|
|