You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
279 lines
11 KiB
279 lines
11 KiB
using System.Text.RegularExpressions; |
|
using Python.Runtime; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Extensions; |
|
using StabilityMatrix.Core.Helper; |
|
using StabilityMatrix.Core.Helper.Cache; |
|
using StabilityMatrix.Core.Helper.HardwareInfo; |
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
using StabilityMatrix.Core.Models.Progress; |
|
using StabilityMatrix.Core.Processes; |
|
using StabilityMatrix.Core.Python; |
|
using StabilityMatrix.Core.Services; |
|
|
|
namespace StabilityMatrix.Core.Models.Packages; |
|
|
|
[Singleton(typeof(BasePackage))] |
|
public class KohyaSs( |
|
IGithubApiCache githubApi, |
|
ISettingsManager settingsManager, |
|
IDownloadService downloadService, |
|
IPrerequisiteHelper prerequisiteHelper, |
|
IPyRunner runner |
|
) : BaseGitPackage(githubApi, settingsManager, downloadService, prerequisiteHelper) |
|
{ |
|
public override string Name => "kohya_ss"; |
|
public override string DisplayName { get; set; } = "kohya_ss"; |
|
public override string Author => "bmaltais"; |
|
public override string Blurb => "A Windows-focused Gradio GUI for Kohya's Stable Diffusion trainers"; |
|
public override string LicenseType => "Apache-2.0"; |
|
public override string LicenseUrl => "https://github.com/bmaltais/kohya_ss/blob/master/LICENSE.md"; |
|
public override string LaunchCommand => "kohya_gui.py"; |
|
|
|
public override Uri PreviewImageUri => |
|
new( |
|
"https://camo.githubusercontent.com/5154eea62c113d5c04393e51a0d0f76ef25a723aad29d256dcc85ead1961cd41/68747470733a2f2f696d672e796f75747562652e636f6d2f76692f6b35696d713031757655592f302e6a7067" |
|
); |
|
public override string OutputFolderName => string.Empty; |
|
|
|
public override bool IsCompatible => HardwareHelper.HasNvidiaGpu(); |
|
|
|
public override TorchVersion GetRecommendedTorchVersion() => TorchVersion.Cuda; |
|
|
|
public override string Disclaimer => |
|
"Nvidia GPU with at least 8GB VRAM is recommended. May be unstable on Linux."; |
|
|
|
public override PackageDifficulty InstallerSortOrder => PackageDifficulty.UltraNightmare; |
|
public override PackageType PackageType => PackageType.SdTraining; |
|
public override bool OfferInOneClickInstaller => false; |
|
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.None; |
|
public override IEnumerable<TorchVersion> AvailableTorchVersions => [TorchVersion.Cuda]; |
|
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => |
|
new[] { SharedFolderMethod.None }; |
|
public override IEnumerable<PackagePrerequisite> Prerequisites => |
|
base.Prerequisites.Concat([PackagePrerequisite.Tkinter]); |
|
|
|
public override List<LaunchOptionDefinition> LaunchOptions => |
|
[ |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Listen Address", |
|
Type = LaunchOptionType.String, |
|
DefaultValue = "127.0.0.1", |
|
Options = ["--listen"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Port", |
|
Type = LaunchOptionType.String, |
|
Options = ["--port"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Username", |
|
Type = LaunchOptionType.String, |
|
Options = ["--username"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Password", |
|
Type = LaunchOptionType.String, |
|
Options = ["--password"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Auto-Launch Browser", |
|
Type = LaunchOptionType.Bool, |
|
Options = ["--inbrowser"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Share", |
|
Type = LaunchOptionType.Bool, |
|
Options = ["--share"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Headless", |
|
Type = LaunchOptionType.Bool, |
|
Options = ["--headless"] |
|
}, |
|
new LaunchOptionDefinition |
|
{ |
|
Name = "Language", |
|
Type = LaunchOptionType.String, |
|
Options = ["--language"] |
|
}, |
|
LaunchOptionDefinition.Extras |
|
]; |
|
|
|
public override async Task InstallPackage( |
|
string installLocation, |
|
TorchVersion torchVersion, |
|
SharedFolderMethod selectedSharedFolderMethod, |
|
DownloadPackageVersionOptions versionOptions, |
|
IProgress<ProgressReport>? progress = null, |
|
Action<ProcessOutput>? onConsoleOutput = null |
|
) |
|
{ |
|
progress?.Report(new ProgressReport(-1f, "Updating submodules", isIndeterminate: true)); |
|
await PrerequisiteHelper |
|
.RunGit( |
|
["submodule", "update", "--init", "--recursive", "--quiet"], |
|
onConsoleOutput, |
|
installLocation |
|
) |
|
.ConfigureAwait(false); |
|
|
|
progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); |
|
// Setup venv |
|
await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); |
|
venvRunner.WorkingDirectory = installLocation; |
|
await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false); |
|
|
|
// Extra dep needed before running setup since v23.0.x |
|
await venvRunner.PipInstall("packaging").ConfigureAwait(false); |
|
|
|
if (Compat.IsWindows) |
|
{ |
|
var setupSmPath = Path.Combine(installLocation, "setup", "setup_sm.py"); |
|
var setupText = """ |
|
import setup_windows |
|
import setup_common |
|
|
|
setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=False) |
|
setup_windows.sync_bits_and_bytes_files() |
|
setup_common.configure_accelerate(run_accelerate=False) |
|
"""; |
|
await File.WriteAllTextAsync(setupSmPath, setupText).ConfigureAwait(false); |
|
|
|
// Install |
|
await venvRunner.CustomInstall("setup/setup_sm.py", onConsoleOutput).ConfigureAwait(false); |
|
await venvRunner.PipInstall("bitsandbytes-windows").ConfigureAwait(false); |
|
} |
|
else if (Compat.IsLinux) |
|
{ |
|
await venvRunner |
|
.CustomInstall( |
|
"setup/setup_linux.py --platform-requirements-file=requirements_linux.txt --no_run_accelerate", |
|
onConsoleOutput |
|
) |
|
.ConfigureAwait(false); |
|
} |
|
} |
|
|
|
public override async Task RunPackage( |
|
string installedPackagePath, |
|
string command, |
|
string arguments, |
|
Action<ProcessOutput>? onConsoleOutput |
|
) |
|
{ |
|
await SetupVenv(installedPackagePath).ConfigureAwait(false); |
|
|
|
// update gui files to point to venv accelerate |
|
await runner.RunInThreadWithLock(() => |
|
{ |
|
var scope = Py.CreateScope(); |
|
scope.Exec( |
|
""" |
|
import ast |
|
|
|
class StringReplacer(ast.NodeTransformer): |
|
def __init__(self, old: str, new: str, replace_count: int = -1): |
|
self.old = old |
|
self.new = new |
|
self.replace_count = replace_count |
|
|
|
def visit_Constant(self, node: ast.Constant) -> ast.Constant: |
|
if isinstance(node.value, str) and self.old in node.value: |
|
new_value = node.value.replace(self.old, self.new, self.replace_count) |
|
node.value = new_value |
|
return node |
|
|
|
def rewrite_module(self, module_text: str) -> str: |
|
tree = ast.parse(module_text) |
|
tree = self.visit(tree) |
|
return ast.unparse(tree) |
|
""" |
|
); |
|
|
|
var replacementAcceleratePath = Compat.IsWindows |
|
? @".\venv\scripts\accelerate" |
|
: "./venv/bin/accelerate"; |
|
|
|
var replacer = scope.InvokeMethod( |
|
"StringReplacer", |
|
"accelerate".ToPython(), |
|
$"{replacementAcceleratePath}".ToPython(), |
|
1.ToPython() |
|
); |
|
|
|
var kohyaGuiDir = Path.Combine(installedPackagePath, "kohya_gui"); |
|
var guiDirExists = Directory.Exists(kohyaGuiDir); |
|
var filesToUpdate = new List<string>(); |
|
|
|
if (guiDirExists) |
|
{ |
|
filesToUpdate.AddRange( |
|
[ |
|
"lora_gui.py", |
|
"dreambooth_gui.py", |
|
"textual_inversion_gui.py", |
|
"wd14_caption_gui.py", |
|
"finetune_gui.py" |
|
] |
|
); |
|
} |
|
else |
|
{ |
|
filesToUpdate.AddRange( |
|
[ |
|
"lora_gui.py", |
|
"dreambooth_gui.py", |
|
"textual_inversion_gui.py", |
|
Path.Combine("library", "wd14_caption_gui.py"), |
|
"finetune_gui.py" |
|
] |
|
); |
|
} |
|
|
|
foreach (var file in filesToUpdate) |
|
{ |
|
var path = Path.Combine(guiDirExists ? kohyaGuiDir : installedPackagePath, file); |
|
var text = File.ReadAllText(path); |
|
if (text.Contains(replacementAcceleratePath.Replace(@"\", @"\\"))) |
|
continue; |
|
|
|
var result = replacer.InvokeMethod("rewrite_module", text.ToPython()); |
|
var resultStr = result.ToString(); |
|
File.WriteAllText(path, resultStr); |
|
} |
|
}); |
|
|
|
void HandleConsoleOutput(ProcessOutput s) |
|
{ |
|
onConsoleOutput?.Invoke(s); |
|
|
|
if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase)) |
|
return; |
|
|
|
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); |
|
var match = regex.Match(s.Text); |
|
if (!match.Success) |
|
return; |
|
|
|
WebUrl = match.Value; |
|
OnStartupComplete(WebUrl); |
|
} |
|
|
|
var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; |
|
|
|
VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); |
|
} |
|
|
|
public override Dictionary<SharedFolderType, IReadOnlyList<string>>? SharedFolders { get; } |
|
public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders { get; } |
|
|
|
public override string MainBranch => "master"; |
|
}
|
|
|