// Multi-Platform Package Manager for Stable Diffusion

using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.RegularExpressions;
using NLog;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Helper.HardwareInfo;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Packages.Extensions;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
/// <summary>
/// Package definition for AUTOMATIC1111's Stable Diffusion WebUI: shared model/output folder
/// layout, launch options, install procedure (venv + pinned torch/torchvision) and launching.
/// </summary>
[Singleton(typeof(BasePackage))]
public class A3WebUI(
    IGithubApiCache githubApi,
    ISettingsManager settingsManager,
    IDownloadService downloadService,
    IPrerequisiteHelper prerequisiteHelper
) : BaseGitPackage(githubApi, settingsManager, downloadService, prerequisiteHelper)
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Matches an "http(s)://host:port" URL in console output.
    /// Built once here rather than on every console line that mentions "Running on".
    /// </summary>
    private static readonly Regex ServerUrlRegex =
        new(@"(https?:\/\/)([^:\s]+):(\d+)", RegexOptions.Compiled);

    public override string Name => "stable-diffusion-webui";
    public override string DisplayName { get; set; } = "Stable Diffusion WebUI";
    public override string Author => "AUTOMATIC1111";
    public override string LicenseType => "AGPL-3.0";

    public override string LicenseUrl =>
        "https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt";
    public override string Blurb => "A browser interface based on Gradio library for Stable Diffusion";
    public override string LaunchCommand => "launch.py";
    public override Uri PreviewImageUri =>
        new("https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png");
    public string RelativeArgsDefinitionScriptPath => "modules.cmd_args";
    public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Recommended;

    public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;

    // From https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/master/models
    public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
        new()
        {
            [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" },
            [SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" },
            [SharedFolderType.GFPGAN] = new[] { "models/GFPGAN" },
            [SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" },
            [SharedFolderType.SwinIR] = new[] { "models/SwinIR" },
            [SharedFolderType.Lora] = new[] { "models/Lora" },
            [SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" },
            [SharedFolderType.ApproxVAE] = new[] { "models/VAE-approx" },
            [SharedFolderType.VAE] = new[] { "models/VAE" },
            [SharedFolderType.DeepDanbooru] = new[] { "models/deepbooru" },
            [SharedFolderType.Karlo] = new[] { "models/karlo" },
            [SharedFolderType.TextualInversion] = new[] { "embeddings" },
            [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
            [SharedFolderType.ControlNet] = new[] { "models/controlnet/ControlNet" },
            [SharedFolderType.Codeformer] = new[] { "models/Codeformer" },
            [SharedFolderType.LDSR] = new[] { "models/LDSR" },
            [SharedFolderType.AfterDetailer] = new[] { "models/adetailer" },
            [SharedFolderType.T2IAdapter] = new[] { "models/controlnet/T2IAdapter" },
            [SharedFolderType.IpAdapter] = new[] { "models/controlnet/IpAdapter" },
            [SharedFolderType.InvokeIpAdapters15] = new[] { "models/controlnet/DiffusersIpAdapters" },
            [SharedFolderType.InvokeIpAdaptersXl] = new[] { "models/controlnet/DiffusersIpAdaptersXL" }
        };

    public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders =>
        new()
        {
            [SharedOutputType.Extras] = new[] { "outputs/extras-images" },
            [SharedOutputType.Saved] = new[] { "log/images" },
            [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" },
            [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" },
            [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" },
            [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" }
        };

    /// <summary>
    /// Command-line launch options shown for this package. Several initial values are derived
    /// from the detected hardware each time this property is read.
    /// </summary>
    [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
    public override List<LaunchOptionDefinition> LaunchOptions =>
        [
            new()
            {
                Name = "Host",
                Type = LaunchOptionType.String,
                DefaultValue = "localhost",
                Options = ["--server-name"]
            },
            new()
            {
                Name = "Port",
                Type = LaunchOptionType.String,
                DefaultValue = "7860",
                Options = ["--port"]
            },
            new()
            {
                Name = "VRAM",
                Type = LaunchOptionType.Bool,
                // Preselect a low/medium VRAM flag based on the largest memory level across GPUs.
                // NOTE(review): if IterGpuInfo() yields no GPUs and GpuInfo.MemoryLevel is a
                // non-nullable enum, Max() throws InvalidOperationException — confirm its type.
                InitialValue = HardwareHelper.IterGpuInfo().Select(gpu => gpu.MemoryLevel).Max() switch
                {
                    MemoryLevel.Low => "--lowvram",
                    MemoryLevel.Medium => "--medvram",
                    _ => null
                },
                Options = ["--lowvram", "--medvram", "--medvram-sdxl"]
            },
            new()
            {
                Name = "Xformers",
                Type = LaunchOptionType.Bool,
                InitialValue = HardwareHelper.HasNvidiaGpu(),
                Options = ["--xformers"]
            },
            new()
            {
                Name = "API",
                Type = LaunchOptionType.Bool,
                InitialValue = true,
                Options = ["--api"]
            },
            new()
            {
                Name = "Auto Launch Web UI",
                Type = LaunchOptionType.Bool,
                InitialValue = false,
                Options = ["--autolaunch"]
            },
            new()
            {
                Name = "Skip Torch CUDA Check",
                Type = LaunchOptionType.Bool,
                InitialValue = !HardwareHelper.HasNvidiaGpu(),
                Options = ["--skip-torch-cuda-test"]
            },
            new()
            {
                Name = "Skip Python Version Check",
                Type = LaunchOptionType.Bool,
                InitialValue = true,
                Options = ["--skip-python-version-check"]
            },
            new()
            {
                Name = "No Half",
                Type = LaunchOptionType.Bool,
                Description = "Do not switch the model to 16-bit floats",
                InitialValue =
                    HardwareHelper.PreferRocm() || HardwareHelper.PreferDirectML() || Compat.IsMacOS,
                Options = ["--no-half"]
            },
            new()
            {
                Name = "Skip SD Model Download",
                Type = LaunchOptionType.Bool,
                InitialValue = false,
                Options = ["--no-download-sd-model"]
            },
            new()
            {
                Name = "Skip Install",
                Type = LaunchOptionType.Bool,
                Options = ["--skip-install"]
            },
            LaunchOptionDefinition.Extras
        ];

    public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
        new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };

    public override IEnumerable<TorchVersion> AvailableTorchVersions =>
        new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm, TorchVersion.Mps };

    public override string MainBranch => "master";
    public override string OutputFolderName => "outputs";
    public override IPackageExtensionManager ExtensionManager => new A3WebUiExtensionManager(this);

    /// <summary>
    /// Creates (or reuses) the "venv" under <paramref name="installLocation"/> and pip-installs
    /// torch 2.1.2 / torchvision 0.16.2 plus the packages from requirements_versions.txt.
    /// Writes a default config.json if one does not exist yet.
    /// </summary>
    /// <param name="installLocation">Root directory of the cloned package.</param>
    /// <param name="torchVersion">Selects the torch extra index (and xformers for CUDA).</param>
    /// <param name="selectedSharedFolderMethod">Not used by this method.</param>
    /// <param name="versionOptions">Version info; its tag is checked for the v1.6.0 httpx workaround.</param>
    /// <param name="progress">Optional progress sink.</param>
    /// <param name="onConsoleOutput">Optional sink for venv/pip console output.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="torchVersion"/> is not one of Cpu, Cuda, Rocm or Mps.
    /// </exception>
    public override async Task InstallPackage(
        string installLocation,
        TorchVersion torchVersion,
        SharedFolderMethod selectedSharedFolderMethod,
        DownloadPackageVersionOptions versionOptions,
        IProgress<ProgressReport>? progress = null,
        Action<ProcessOutput>? onConsoleOutput = null
    )
    {
        progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));

        var venvPath = Path.Combine(installLocation, "venv");
        // Remember whether this is a reinstall over an existing venv, before Setup creates one.
        var exists = Directory.Exists(venvPath);

        await using var venvRunner = new PyVenvRunner(venvPath);
        venvRunner.WorkingDirectory = installLocation;
        await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false);

        await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);

        progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true));

        var requirements = new FilePath(installLocation, "requirements_versions.txt");

        // torch/torchvision are pinned here, so torch entries from the requirements file are excluded.
        var pipArgs = new PipInstallArgs()
            .WithTorch("==2.1.2")
            .WithTorchVision("==0.16.2")
            .WithTorchExtraIndex(
                torchVersion switch
                {
                    TorchVersion.Cpu => "cpu",
                    TorchVersion.Cuda => "cu121",
                    TorchVersion.Rocm => "rocm5.6",
                    TorchVersion.Mps => "nightly/cpu",
                    _ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
                }
            )
            .WithParsedFromRequirementsTxt(
                await requirements.ReadAllTextAsync().ConfigureAwait(false),
                excludePattern: "torch"
            );

        if (torchVersion == TorchVersion.Cuda)
        {
            pipArgs = pipArgs.WithXFormers("==0.0.23.post1");
        }

        // v1.6.0 needs a httpx qualifier to fix a gradio issue
        if (versionOptions.VersionTag?.Contains("1.6.0") is true)
        {
            pipArgs = pipArgs.AddArg("httpx==0.24.1");
        }

        // Add jsonmerge to fix https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/12482
        pipArgs = pipArgs.AddArg("jsonmerge");

        if (exists)
        {
            // Reinstalling into an existing venv: make pip replace what is already there.
            pipArgs = pipArgs.AddArg("--upgrade");
            pipArgs = pipArgs.AddArg("--force-reinstall");
        }

        await venvRunner.PipInstall(pipArgs, onConsoleOutput).ConfigureAwait(false);

        progress?.Report(new ProgressReport(-1f, "Updating configuration", isIndeterminate: true));

        // Create and add {"show_progress_type": "TAESD"} to config.json
        // Only add if the file doesn't exist, so user settings are never overwritten.
        var configPath = Path.Combine(installLocation, "config.json");
        if (!File.Exists(configPath))
        {
            var config = new JsonObject { { "show_progress_type", "TAESD" } };
            await File.WriteAllTextAsync(configPath, config.ToString()).ConfigureAwait(false);
        }

        progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false));
    }

    /// <summary>
    /// Runs <paramref name="command"/> (relative to the install directory) with the package venv,
    /// detached. Console output is forwarded to <paramref name="onConsoleOutput"/>. The first
    /// http(s)://host:port URL on a line containing "Running on" is stored in WebUrl and reported
    /// via OnStartupComplete.
    /// </summary>
    /// <param name="installedPackagePath">Root directory of the installed package.</param>
    /// <param name="command">Script path relative to <paramref name="installedPackagePath"/>.</param>
    /// <param name="arguments">Extra command-line arguments appended after the script path.</param>
    /// <param name="onConsoleOutput">Optional sink for process output.</param>
    public override async Task RunPackage(
        string installedPackagePath,
        string command,
        string arguments,
        Action<ProcessOutput>? onConsoleOutput
    )
    {
        await SetupVenv(installedPackagePath).ConfigureAwait(false);

        void HandleConsoleOutput(ProcessOutput s)
        {
            onConsoleOutput?.Invoke(s);

            if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase))
            {
                return;
            }

            var match = ServerUrlRegex.Match(s.Text);
            if (!match.Success)
            {
                return;
            }

            WebUrl = match.Value;
            OnStartupComplete(WebUrl);
        }

        var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";

        // TrimEnd drops the trailing space left when no extra arguments are given.
        VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
    }

    /// <summary>
    /// Git-based extension manager. Extensions are installed under "extensions", and the official
    /// AUTOMATIC1111 extension index is the default manifest.
    /// </summary>
    private class A3WebUiExtensionManager(A3WebUI package)
        : GitPackageExtensionManager(package.PrerequisiteHelper)
    {
        public override string RelativeInstallDirectory => "extensions";

        public override IEnumerable<ExtensionManifest> DefaultManifests =>
            [
                new ExtensionManifest(
                    new Uri(
                        "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json"
                    )
                )
            ];

        /// <summary>
        /// Downloads and parses an A1111-format extension index.
        /// </summary>
        /// <param name="manifest">Manifest whose Uri points at the index JSON.</param>
        /// <param name="cancellationToken">Cancels the download.</param>
        /// <returns>
        /// The parsed extensions. Returns an empty sequence if the JSON is null or any download/parse
        /// error occurs; the error is logged.
        /// </returns>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="cancellationToken"/> was cancelled during the download.
        /// </exception>
        public override async Task<IEnumerable<PackageExtension>> GetManifestExtensionsAsync(
            ExtensionManifest manifest,
            CancellationToken cancellationToken = default
        )
        {
            try
            {
                // Get json
                var content = await package
                    .DownloadService.GetContentAsync(manifest.Uri.ToString(), cancellationToken)
                    .ConfigureAwait(false);

                // Parse json
                var jsonManifest = JsonSerializer.Deserialize<A1111ExtensionManifest>(
                    content,
                    A1111ExtensionManifestSerializerContext.Default.Options
                );

                return jsonManifest?.GetPackageExtensions() ?? Enumerable.Empty<PackageExtension>();
            }
            // Best-effort: swallow and log failures, but let a cancellation the caller requested
            // propagate instead of being logged as an error. Timeouts and other failures
            // (token not cancelled) are still swallowed.
            catch (Exception e)
                when (e is not OperationCanceledException || !cancellationToken.IsCancellationRequested)
            {
                Logger.Error(e, "Failed to get extensions from manifest");
                return Enumerable.Empty<PackageExtension>();
            }
        }
    }
}