Browse Source

Fix kohya launch not working & fix wrong package name in breadcrumb

pull/629/head
JT 7 months ago
parent
commit
aca2e7a767
  1. 2
      StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs
  2. 7
      StabilityMatrix.Core/Models/ConnectedModelInfo.cs
  3. 87
      StabilityMatrix.Core/Models/Packages/KohyaSs.cs
  4. 12
      StabilityMatrix.Core/Models/Packages/SDWebForge.cs
  5. 3
      StabilityMatrix.Core/Python/PyVenvRunner.cs

2
StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs

@ -26,7 +26,7 @@ public partial class RunningPackageViewModel : PageViewModelBase, IDisposable, I
public PackagePair RunningPackage { get; }
public ConsoleViewModel Console { get; }
public override string Title => RunningPackage.InstalledPackage.PackageName ?? "Running Package";
public override string Title => RunningPackage.InstalledPackage.DisplayName ?? "Running Package";
public override IconSource IconSource => new SymbolIconSource();
[ObservableProperty]

7
StabilityMatrix.Core/Models/ConnectedModelInfo.cs

@ -56,12 +56,19 @@ public class ConnectedModelInfo
}
public static ConnectedModelInfo? FromJson(string json)
{
try
{
return JsonSerializer.Deserialize<ConnectedModelInfo>(
json,
new JsonSerializerOptions { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }
);
}
catch (JsonException)
{
return default;
}
}
/// <summary>
/// Saves the model info to a json file in the specified directory.

87
StabilityMatrix.Core/Models/Packages/KohyaSs.cs

@ -159,90 +159,7 @@ public class KohyaSs(
Action<ProcessOutput>? onConsoleOutput
)
{
await SetupVenv(installedPackagePath).ConfigureAwait(false);
// update gui files to point to venv accelerate
await runner.RunInThreadWithLock(() =>
{
var scope = Py.CreateScope();
scope.Exec(
"""
import ast
class StringReplacer(ast.NodeTransformer):
def __init__(self, old: str, new: str, replace_count: int = -1):
self.old = old
self.new = new
self.replace_count = replace_count
def visit_Constant(self, node: ast.Constant) -> ast.Constant:
if isinstance(node.value, str) and self.old in node.value:
new_value = node.value.replace(self.old, self.new, self.replace_count)
node.value = new_value
return node
def rewrite_module(self, module_text: str) -> str:
tree = ast.parse(module_text)
tree = self.visit(tree)
return ast.unparse(tree)
"""
);
var replacementAcceleratePath = Compat.IsWindows
? @".\venv\scripts\accelerate"
: "./venv/bin/accelerate";
var replacer = scope.InvokeMethod(
"StringReplacer",
"accelerate".ToPython(),
$"{replacementAcceleratePath}".ToPython(),
1.ToPython()
);
var kohyaGuiDir = Path.Combine(installedPackagePath, "kohya_gui");
var guiDirExists = Directory.Exists(kohyaGuiDir);
var filesToUpdate = new List<string>();
if (guiDirExists)
{
filesToUpdate.AddRange(
[
"lora_gui.py",
"dreambooth_gui.py",
"textual_inversion_gui.py",
"wd14_caption_gui.py",
"finetune_gui.py"
]
);
}
else
{
filesToUpdate.AddRange(
[
"lora_gui.py",
"dreambooth_gui.py",
"textual_inversion_gui.py",
Path.Combine("library", "wd14_caption_gui.py"),
"finetune_gui.py"
]
);
}
foreach (var file in filesToUpdate)
{
var path = Path.Combine(guiDirExists ? kohyaGuiDir : installedPackagePath, file);
if (!File.Exists(path))
continue;
var text = File.ReadAllText(path);
if (text.Contains(replacementAcceleratePath.Replace(@"\", @"\\")))
continue;
var result = replacer.InvokeMethod("rewrite_module", text.ToPython());
var resultStr = result.ToString();
File.WriteAllText(path, resultStr);
}
});
var venvRunner = await SetupVenvPure(installedPackagePath).ConfigureAwait(false);
void HandleConsoleOutput(ProcessOutput s)
{
@ -262,7 +179,7 @@ public class KohyaSs(
var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
venvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
}
public override Dictionary<SharedFolderType, IReadOnlyList<string>>? SharedFolders { get; }

12
StabilityMatrix.Core/Models/Packages/SDWebForge.cs

@ -139,6 +139,13 @@ public class SDWebForge(
progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true));
var requirements = new FilePath(installLocation, "requirements_versions.txt");
var requirementsContent = await requirements.ReadAllTextAsync().ConfigureAwait(false);
if (!requirementsContent.Contains("pydantic"))
{
requirementsContent += "pydantic==1.10.15";
await requirements.WriteAllTextAsync(requirementsContent).ConfigureAwait(false);
}
var pipArgs = new PipInstallArgs();
if (torchVersion is TorchVersion.DirectMl)
{
@ -161,10 +168,7 @@ public class SDWebForge(
);
}
pipArgs = pipArgs.WithParsedFromRequirementsTxt(
await requirements.ReadAllTextAsync().ConfigureAwait(false),
excludePattern: "torch"
);
pipArgs = pipArgs.WithParsedFromRequirementsTxt(requirementsContent, excludePattern: "torch");
await venvRunner.PipInstall(pipArgs, onConsoleOutput).ConfigureAwait(false);
progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false));

3
StabilityMatrix.Core/Python/PyVenvRunner.cs

@ -545,7 +545,8 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
if (Compat.IsWindows)
{
var portableGitBin = GlobalConfig.LibraryDir.JoinDir("PortableGit", "bin");
env["PATH"] = Compat.GetEnvPathWithExtensions(portableGitBin);
var venvBin = RootPath.JoinDir(RelativeBinPath);
env["PATH"] = Compat.GetEnvPathWithExtensions(portableGitBin, venvBin);
env["GIT"] = portableGitBin.JoinFile("git.exe");
}

Loading…
Cancel
Save