diff --git a/StabilityMatrix.Avalonia/Services/RunningPackageService.cs b/StabilityMatrix.Avalonia/Services/RunningPackageService.cs index 89f0d903..d3cb114f 100644 --- a/StabilityMatrix.Avalonia/Services/RunningPackageService.cs +++ b/StabilityMatrix.Avalonia/Services/RunningPackageService.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Immutable; -using System.Collections.ObjectModel; using System.Linq; using System.Threading.Tasks; using Avalonia.Controls.Notifications; @@ -10,7 +9,6 @@ using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.ViewModels; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; -using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Factory; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.FileInterfaces; @@ -114,8 +112,6 @@ public partial class RunningPackageService( await basePackage.RunPackage(packagePath, command, userArgsString, o => console.Post(o)); var runningPackage = new PackagePair(installedPackage, basePackage); - EventManager.Instance.OnRunningPackageStatusChanged(runningPackage); - var viewModel = new RunningPackageViewModel( settingsManager, notificationService, diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs index fafd1206..4f271517 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using System.Threading.Tasks; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; @@ -30,6 +31,7 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa private readonly ISettingsManager settingsManager; private readonly INavigationService navigationService; private readonly IPackageFactory packageFactory; + private readonly RunningPackageService runningPackageService; [ObservableProperty] private string title = "Hello"; @@ -58,12 +60,14 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa public InferenceConnectionHelpViewModel( ISettingsManager settingsManager, INavigationService navigationService, - IPackageFactory packageFactory + IPackageFactory packageFactory, + RunningPackageService runningPackageService ) { this.settingsManager = settingsManager; this.navigationService = navigationService; this.packageFactory = packageFactory; + this.runningPackageService = runningPackageService; // Get comfy type installed packages var comfyPackages = this.settingsManager.Settings.InstalledPackages.Where( @@ -122,14 +126,11 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa /// Request launch of the selected package /// [RelayCommand] - private void LaunchSelectedPackage() + private async Task LaunchSelectedPackage() { - if (SelectedPackage?.Id is { } id) + if (SelectedPackage is not null) { - Dispatcher.UIThread.Post(() => - { - EventManager.Instance.OnPackageLaunchRequested(id); - }); + await runningPackageService.StartPackage(SelectedPackage); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs index c36fa128..42378ce2 100644 --- a/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs @@ -1,6 +1,8 @@ using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.Collections.ObjectModel; +using System.Collections.Specialized; using System.Linq; using System.Reactive.Linq; using System.Text.Json; @@ -51,6 +53,8 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable private readonly ServiceManager vmFactory; private readonly IModelIndexService modelIndexService; private readonly ILiteDbContext liteDbContext; + private readonly RunningPackageService runningPackageService; + private Guid? selectedPackageId; public override string Title => "Inference"; public override IconSource IconSource => @@ -86,6 +90,8 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable public bool IsComfyRunning => RunningPackage?.BasePackage is ComfyUI; + private IDisposable? onStartupComplete; + public InferenceViewModel( ServiceManager vmFactory, INotificationService notificationService, @@ -93,6 +99,7 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable ISettingsManager settingsManager, IModelIndexService modelIndexService, ILiteDbContext liteDbContext, + RunningPackageService runningPackageService, SharedState sharedState ) { @@ -101,12 +108,13 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable this.settingsManager = settingsManager; this.modelIndexService = modelIndexService; this.liteDbContext = liteDbContext; + this.runningPackageService = runningPackageService; ClientManager = inferenceClientManager; SharedState = sharedState; // Keep RunningPackage updated with the current package pair - EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged; + runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged; // "Send to Inference" EventManager.Instance.InferenceTextToImageRequested += OnInferenceTextToImageRequested; @@ -118,54 +126,77 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService); } + private void DisconnectFromComfy() + { + RunningPackage = null; + + // Cancel any pending connection + if (ConnectCancelCommand.CanExecute(null)) + { + ConnectCancelCommand.Execute(null); + } + onStartupComplete?.Dispose(); + onStartupComplete = null; + IsWaitingForConnection = false; + + // Disconnect + Logger.Trace("On package close - disconnecting"); + DisconnectCommand.Execute(null); + } + /// /// Updates the RunningPackage property when the running package changes. /// Also starts a connection to the backend if a new ComfyUI package is running. /// And disconnects if the package is closed. /// - private void OnRunningPackageStatusChanged(object? sender, RunningPackageStatusChangedEventArgs e) + private void RunningPackagesOnCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e) { - RunningPackage = e.CurrentPackagePair; + if ( + e.NewItems?.OfType>().Select(x => x.Value) + is not { } newItems + ) + { + if (RunningPackage != null) + { + DisconnectFromComfy(); + } + return; + } - IDisposable? onStartupComplete = null; + var comfyViewModel = newItems.FirstOrDefault( + vm => + vm.RunningPackage.InstalledPackage.Id == selectedPackageId + || vm.RunningPackage.BasePackage is ComfyUI + ); - Dispatcher.UIThread.Post(() => + if (comfyViewModel is null && RunningPackage?.BasePackage is ComfyUI) { - if (e.CurrentPackagePair?.BasePackage is ComfyUI package) - { - IsWaitingForConnection = true; - onStartupComplete = Observable - .FromEventPattern(package, nameof(package.StartupComplete)) - .Take(1) - .Subscribe(_ => + DisconnectFromComfy(); + } + else if (comfyViewModel != null && RunningPackage == null) + { + IsWaitingForConnection = true; + RunningPackage = comfyViewModel.RunningPackage; + onStartupComplete = Observable + .FromEventPattern( + comfyViewModel.RunningPackage.BasePackage, + nameof(comfyViewModel.RunningPackage.BasePackage.StartupComplete) + ) + .Take(1) + .Subscribe(_ => + { + Dispatcher.UIThread.Post(() => { - Dispatcher.UIThread.Post(() => + if (ConnectCommand.CanExecute(null)) { - if (ConnectCommand.CanExecute(null)) - { - Logger.Trace("On package launch - starting connection"); - ConnectCommand.Execute(null); - } - IsWaitingForConnection = false; - }); - }); - } - else - { - // Cancel any pending connection - if (ConnectCancelCommand.CanExecute(null)) - { - ConnectCancelCommand.Execute(null); - } - onStartupComplete?.Dispose(); - onStartupComplete = null; - IsWaitingForConnection = false; + Logger.Trace("On package launch - starting connection"); + ConnectCommand.Execute(null); + } - // Disconnect - Logger.Trace("On package close - disconnecting"); - DisconnectCommand.Execute(null); - } - }); + IsWaitingForConnection = false; + }); + }); + } } public override void OnLoaded() @@ -390,7 +421,12 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable private async Task ShowConnectionHelp() { var vm = vmFactory.Get(); - await vm.CreateDialog().ShowAsync(); + var result = await vm.CreateDialog().ShowAsync(); + + if (result != ContentDialogResult.Primary) + return; + + selectedPackageId = vm.SelectedPackage?.Id; } /// diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs index 2e8bf0a9..614fef81 100644 --- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Collections.Specialized; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -36,15 +37,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.PackageManager; [ManagedService] [Transient] -public partial class PackageCardViewModel( - ILogger logger, - IPackageFactory packageFactory, - INotificationService notificationService, - ISettingsManager settingsManager, - INavigationService navigationService, - ServiceManager vmFactory, - RunningPackageService runningPackageService -) : ProgressViewModel +public partial class PackageCardViewModel : ProgressViewModel { private string webUiUrl = string.Empty; @@ -93,6 +86,59 @@ public partial class PackageCardViewModel( [ObservableProperty] private bool showWebUiButton; + private readonly ILogger logger; + private readonly IPackageFactory packageFactory; + private readonly INotificationService notificationService; + private readonly ISettingsManager settingsManager; + private readonly INavigationService navigationService; + private readonly ServiceManager vmFactory; + private readonly RunningPackageService runningPackageService; + + /// + public PackageCardViewModel( + ILogger logger, + IPackageFactory packageFactory, + INotificationService notificationService, + ISettingsManager settingsManager, + INavigationService navigationService, + ServiceManager vmFactory, + RunningPackageService runningPackageService + ) + { + this.logger = logger; + this.packageFactory = packageFactory; + this.notificationService = notificationService; + this.settingsManager = settingsManager; + this.navigationService = navigationService; + this.vmFactory = vmFactory; + this.runningPackageService = runningPackageService; + + runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged; + } + + private void RunningPackagesOnCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e) + { + if ( + e.NewItems?.OfType>().Select(x => x.Value) + is not { } newItems + ) + return; + + var runningViewModel = newItems.FirstOrDefault( + x => x.RunningPackage.InstalledPackage.Id == Package?.Id + ); + if (runningViewModel is not null) + { + IsRunning = true; + runningViewModel.RunningPackage.BasePackage.Exited += BasePackageOnExited; + runningViewModel.RunningPackage.BasePackage.StartupComplete += RunningPackageOnStartupComplete; + } + else if (runningViewModel is null && IsRunning) + { + IsRunning = false; + } + } + partial void OnPackageChanged(InstalledPackage? value) { if (string.IsNullOrWhiteSpace(value?.PackageName)) @@ -232,7 +278,6 @@ public partial class PackageCardViewModel( private void BasePackageOnExited(object? sender, int exitCode) { - EventManager.Instance.OnRunningPackageStatusChanged(null); Dispatcher .UIThread.InvokeAsync(async () => { diff --git a/StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs index 2c188633..2758269f 100644 --- a/StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs @@ -1,7 +1,5 @@ using System; using System.Threading.Tasks; -using Avalonia.Threading; -using AvaloniaEdit; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; @@ -19,7 +17,7 @@ using SymbolIconSource = FluentIcons.Avalonia.Fluent.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(ConsoleOutputPage))] -public partial class RunningPackageViewModel : PageViewModelBase +public partial class RunningPackageViewModel : PageViewModelBase, IDisposable, IAsyncDisposable { private readonly INotificationService notificationService; private string webUiUrl = string.Empty; @@ -51,7 +49,7 @@ public partial class RunningPackageViewModel : PageViewModelBase Console = console; Console.Document.LineCountChanged += DocumentOnLineCountChanged; RunningPackage.BasePackage.StartupComplete += BasePackageOnStartupComplete; - runningPackage.BasePackage.Exited += BasePackageOnExited; + RunningPackage.BasePackage.Exited += BasePackageOnExited; settingsManager.RelayPropertyFor( this, @@ -65,6 +63,9 @@ public partial class RunningPackageViewModel : PageViewModelBase { IsRunning = false; ShowWebUiButton = false; + Console.Document.LineCountChanged -= DocumentOnLineCountChanged; + RunningPackage.BasePackage.StartupComplete -= BasePackageOnStartupComplete; + RunningPackage.BasePackage.Exited -= BasePackageOnExited; } private void BasePackageOnStartupComplete(object? sender, string url) @@ -123,4 +124,14 @@ public partial class RunningPackageViewModel : PageViewModelBase } } } + + public void Dispose() + { + Console.Dispose(); + } + + public async ValueTask DisposeAsync() + { + await Console.DisposeAsync(); + } } diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 0e959e89..3e1ccee2 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -2,6 +2,7 @@ using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Packages; +using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Helper.Factory; @@ -13,6 +14,7 @@ public class PackageFactory : IPackageFactory private readonly ISettingsManager settingsManager; private readonly IDownloadService downloadService; private readonly IPrerequisiteHelper prerequisiteHelper; + private readonly IPyRunner pyRunner; /// /// Mapping of package.Name to package @@ -24,13 +26,15 @@ public class PackageFactory : IPackageFactory IGithubApiCache githubApiCache, ISettingsManager settingsManager, IDownloadService downloadService, - IPrerequisiteHelper prerequisiteHelper + IPrerequisiteHelper prerequisiteHelper, + IPyRunner pyRunner ) { this.githubApiCache = githubApiCache; this.settingsManager = settingsManager; this.downloadService = downloadService; this.prerequisiteHelper = prerequisiteHelper; + this.pyRunner = pyRunner; this.basePackages = basePackages.ToDictionary(x => x.Name); } @@ -44,7 +48,44 @@ public class PackageFactory : IPackageFactory => new A3WebUI(githubApiCache, settingsManager, downloadService, prerequisiteHelper), "Fooocus-ControlNet-SDXL" => new FocusControlNet(githubApiCache, settingsManager, downloadService, prerequisiteHelper), - _ => throw new ArgumentOutOfRangeException() + "Fooocus-MRE" + => new FooocusMre(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "InvokeAI" => new InvokeAI(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "kohya_ss" + => new KohyaSs( + githubApiCache, + settingsManager, + downloadService, + prerequisiteHelper, + pyRunner + ), + "OneTrainer" + => new OneTrainer(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "RuinedFooocus" + => new RuinedFooocus(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "stable-diffusion-webui-forge" + => new SDWebForge(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "stable-diffusion-webui-directml" + => new StableDiffusionDirectMl( + githubApiCache, + settingsManager, + downloadService, + prerequisiteHelper + ), + "stable-diffusion-webui-ux" + => new StableDiffusionUx( + githubApiCache, + settingsManager, + downloadService, + prerequisiteHelper + ), + "StableSwarmUI" + => new StableSwarm(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "automatic" + => new VladAutomatic(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + "voltaML-fast-stable-diffusion" + => new VoltaML(githubApiCache, settingsManager, downloadService, prerequisiteHelper), + _ => throw new ArgumentOutOfRangeException(nameof(installedPackage)) }; }