using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Helper.Factory;
[Singleton(typeof(IPackageFactory))]
public class PackageFactory : IPackageFactory
{
    private readonly IGithubApiCache githubApiCache;
    private readonly ISettingsManager settingsManager;
    private readonly IDownloadService downloadService;
    private readonly IPrerequisiteHelper prerequisiteHelper;
    private readonly IPyRunner pyRunner;

    /// <summary>
    /// Mapping of package.Name to package
    /// </summary>
    private readonly Dictionary<string, BasePackage> basePackages;

    /// <summary>
    /// Creates the factory from the registered package instances and the services
    /// required to construct fresh package instances in <see cref="GetNewBasePackage"/>.
    /// </summary>
    /// <param name="basePackages">All registered packages; indexed by their <c>Name</c>.</param>
    /// <param name="githubApiCache">Passed to every package created by <see cref="GetNewBasePackage"/>.</param>
    /// <param name="settingsManager">Passed to every package created by <see cref="GetNewBasePackage"/>.</param>
    /// <param name="downloadService">Passed to every package created by <see cref="GetNewBasePackage"/>.</param>
    /// <param name="prerequisiteHelper">Passed to every package created by <see cref="GetNewBasePackage"/>.</param>
    /// <param name="pyRunner">Only passed to packages that need it (currently kohya_ss).</param>
    /// <exception cref="ArgumentException">
    /// Thrown (by <c>ToDictionary</c>) if two registered packages share the same <c>Name</c>.
    /// </exception>
    public PackageFactory(
        IEnumerable<BasePackage> basePackages,
        IGithubApiCache githubApiCache,
        ISettingsManager settingsManager,
        IDownloadService downloadService,
        IPrerequisiteHelper prerequisiteHelper,
        IPyRunner pyRunner
    )
    {
        this.githubApiCache = githubApiCache;
        this.settingsManager = settingsManager;
        this.downloadService = downloadService;
        this.prerequisiteHelper = prerequisiteHelper;
        this.pyRunner = pyRunner;
        this.basePackages = basePackages.ToDictionary(x => x.Name);
    }

    /// <summary>
    /// Constructs a new, separate <see cref="BasePackage"/> instance matching the
    /// <c>PackageName</c> of <paramref name="installedPackage"/>.
    /// </summary>
    /// <param name="installedPackage">The installed package whose <c>PackageName</c> selects the type.</param>
    /// <returns>A freshly constructed package instance (not one of the registered instances).</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown when <c>PackageName</c> is null or not one of the known package names.
    /// </exception>
    public BasePackage GetNewBasePackage(InstalledPackage installedPackage)
    {
        return installedPackage.PackageName switch
        {
            "ComfyUI" => new ComfyUI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "Fooocus" => new Fooocus(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "stable-diffusion-webui"
                => new A3WebUI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "Fooocus-ControlNet-SDXL"
                => new FocusControlNet(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "Fooocus-MRE"
                => new FooocusMre(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "InvokeAI" => new InvokeAI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            // kohya_ss is the only package here that also takes the Python runner.
            "kohya_ss"
                => new KohyaSs(githubApiCache, settingsManager, downloadService, prerequisiteHelper, pyRunner),
            "OneTrainer"
                => new OneTrainer(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "RuinedFooocus"
                => new RuinedFooocus(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "stable-diffusion-webui-forge"
                => new SDWebForge(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "stable-diffusion-webui-directml"
                => new StableDiffusionDirectMl(
                    githubApiCache,
                    settingsManager,
                    downloadService,
                    prerequisiteHelper
                ),
            "stable-diffusion-webui-ux"
                => new StableDiffusionUx(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "StableSwarmUI"
                => new StableSwarm(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "automatic"
                => new VladAutomatic(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            "voltaML-fast-stable-diffusion"
                => new VoltaML(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
            // Include the offending name so the failure is diagnosable from logs.
            _
                => throw new ArgumentOutOfRangeException(
                    nameof(installedPackage),
                    installedPackage.PackageName,
                    $"Unknown package name '{installedPackage.PackageName}'"
                )
        };
    }

    /// <summary>
    /// Returns all registered packages ordered by <c>InstallerSortOrder</c>, then by <c>DisplayName</c>.
    /// </summary>
    /// <remarks>The ordering is a deferred LINQ query, evaluated on each enumeration.</remarks>
    public IEnumerable<BasePackage> GetAllAvailablePackages()
    {
        return basePackages.Values.OrderBy(p => p.InstallerSortOrder).ThenBy(p => p.DisplayName);
    }

    /// <summary>
    /// Looks up a registered package by its <c>Name</c>.
    /// </summary>
    /// <returns>The package, or null if <paramref name="packageName"/> is null or not registered.</returns>
    public BasePackage? FindPackageByName(string? packageName)
    {
        return packageName == null ? null : basePackages.GetValueOrDefault(packageName);
    }

    /// <summary>
    /// Looks up a registered package by its <c>Name</c>.
    /// </summary>
    /// <returns>The package, or null if no package is registered under <paramref name="packageName"/>.</returns>
    /// <remarks>
    /// Uses a non-throwing lookup so the result honours the nullable return type
    /// (the plain dictionary indexer would throw <see cref="KeyNotFoundException"/> instead).
    /// </remarks>
    public BasePackage? this[string packageName] =>
        basePackages.TryGetValue(packageName, out var package) ? package : null;

    /// <summary>
    /// Pairs an installed package with the registered package matching its <c>PackageName</c>.
    /// </summary>
    /// <returns>
    /// The pair, or null if <paramref name="installedPackage"/> or its <c>PackageName</c> is null,
    /// or if no registered package has that name.
    /// </returns>
    public PackagePair? GetPackagePair(InstalledPackage? installedPackage)
    {
        if (installedPackage?.PackageName is not { } packageName)
        {
            return null;
        }

        return basePackages.TryGetValue(packageName, out var basePackage)
            ? new PackagePair(installedPackage, basePackage)
            : null;
    }

    /// <summary>
    /// Returns the registered packages whose <c>PackageType</c> equals <paramref name="packageType"/>.
    /// </summary>
    /// <remarks>The filter is a deferred LINQ query, evaluated on each enumeration.</remarks>
    public IEnumerable<BasePackage> GetPackagesByType(PackageType packageType) =>
        basePackages.Values.Where(p => p.PackageType == packageType);
}