diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs index e1fe3ebc..08dacc4e 100644 --- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs +++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs @@ -2,6 +2,7 @@ using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; @@ -11,6 +12,7 @@ namespace StabilityMatrix.Core.Models.Packages; public class InvokeAI : BaseGitPackage { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + private const string RelativeRootPath = "invokeai-root"; public override string Name => "InvokeAI"; public override string DisplayName { get; set; } = "InvokeAI"; @@ -46,13 +48,94 @@ public class InvokeAI : BaseGitPackage { } - // https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L86 - public override Dictionary SharedFolders => new() + public override Dictionary> SharedFolders => new() { + [SharedFolderType.StableDiffusion] = new[] + { + RelativeRootPath + "/models/sd-1/main", + RelativeRootPath + "/models/sd-2/main", + RelativeRootPath + "/models/sdxl/main", + RelativeRootPath + "/models/sdxl-refiner/main", + }, + [SharedFolderType.Lora] = new[] + { + RelativeRootPath + "/models/sd-1/lora", + RelativeRootPath + "/models/sd-2/lora", + RelativeRootPath + "/models/sdxl/lora", + RelativeRootPath + "/models/sdxl-refiner/lora", + }, + [SharedFolderType.TextualInversion] = new[] + { + RelativeRootPath + "/models/sd-1/embedding", + RelativeRootPath + "/models/sd-2/embedding", + RelativeRootPath + "/models/sdxl/embedding", + RelativeRootPath + "/models/sdxl-refiner/embedding", + }, + [SharedFolderType.VAE] = new[] + { + RelativeRootPath + "/models/sd-1/vae", + RelativeRootPath + "/models/sd-2/vae", + RelativeRootPath + "/models/sdxl/vae", + RelativeRootPath + "/models/sdxl-refiner/vae", + }, + [SharedFolderType.ControlNet] = new[] + { + RelativeRootPath + "/models/sd-1/controlnet", + RelativeRootPath + "/models/sd-2/controlnet", + RelativeRootPath + "/models/sdxl/controlnet", + RelativeRootPath + "/models/sdxl-refiner/controlnet", + }, }; + // https://github.com/invoke-ai/InvokeAI/blob/main/docs/features/CONFIGURATION.md public override List LaunchOptions => new List { + new() + { + Name = "Host", + Type = LaunchOptionType.String, + DefaultValue = "localhost", + Options = new List {"--host"} + }, + new() + { + Name = "Port", + Type = LaunchOptionType.String, + DefaultValue = "9090", + Options = new List {"--port"} + }, + new() + { + Name = "Allow Origins", + Description = "List of host names or IP addresses that are allowed to connect to the " + + "InvokeAI API in the format ['host1','host2',...]", + Type = LaunchOptionType.String, + DefaultValue = "[]", + Options = new List {"--allow-origins"} + }, + new() + { + Name = "Always use CPU", + Type = LaunchOptionType.Bool, + Options = new List {"--always_use_cpu"} + }, + new() + { + Name = "Precision", + Type = LaunchOptionType.Bool, + Options = new List + { + "--precision auto", + "--precision float16", + "--precision float32", + } + }, + new() + { + Name = "Aggressively free up GPU memory after each operation", + Type = LaunchOptionType.Bool, + Options = new List {"--free_gpu_mem"} + }, LaunchOptionDefinition.Extras }; @@ -73,6 +156,7 @@ public class InvokeAI : BaseGitPackage venvRunner.WorkingDirectory = InstallLocation; await venvRunner.Setup().ConfigureAwait(false); + venvRunner.EnvironmentVariables = GetEnvVars(InstallLocation); var gpus = HardwareHelper.IterGpuInfo().ToList(); @@ -122,6 +206,8 @@ public class InvokeAI : BaseGitPackage { await SetupVenv(installedPackagePath).ConfigureAwait(false); + VenvRunner.EnvironmentVariables = GetEnvVars(installedPackagePath); + // Launch command is for a console entry point, and not a direct script var entryPoint = await VenvRunner.GetEntryPoint(command).ConfigureAwait(false); @@ -161,4 +247,22 @@ public class InvokeAI : BaseGitPackage VenvRunner.RunDetached($"-c \"{code}\"".TrimEnd(), OnConsoleOutput, OnExit); } + + private Dictionary GetEnvVars(DirectoryPath installPath) + { + // Set additional required environment variables + var env = new Dictionary(); + if (SettingsManager.Settings.EnvironmentVariables is not null) + { + env.Update(SettingsManager.Settings.EnvironmentVariables); + } + + // Need to make subdirectory because they store config in the + // directory *above* the root directory + var root = installPath.JoinDir("invokeai_root"); + root.Create(); + env["INVOKEAI_ROOT"] = root; + + return env; + } }