Browse Source

fix inference launch connection stuff to work with multi-package

pull/629/head
JT 9 months ago
parent
commit
45f48ce36f
  1. 4
      StabilityMatrix.Avalonia/Services/RunningPackageService.cs
  2. 15
      StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs
  3. 84
      StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
  4. 65
      StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
  5. 19
      StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs
  6. 45
      StabilityMatrix.Core/Helper/Factory/PackageFactory.cs

4
StabilityMatrix.Avalonia/Services/RunningPackageService.cs

@ -1,6 +1,5 @@
using System; using System;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls.Notifications; using Avalonia.Controls.Notifications;
@ -10,7 +9,6 @@ using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.ViewModels; using StabilityMatrix.Avalonia.ViewModels;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory; using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.FileInterfaces;
@ -114,8 +112,6 @@ public partial class RunningPackageService(
await basePackage.RunPackage(packagePath, command, userArgsString, o => console.Post(o)); await basePackage.RunPackage(packagePath, command, userArgsString, o => console.Post(o));
var runningPackage = new PackagePair(installedPackage, basePackage); var runningPackage = new PackagePair(installedPackage, basePackage);
EventManager.Instance.OnRunningPackageStatusChanged(runningPackage);
var viewModel = new RunningPackageViewModel( var viewModel = new RunningPackageViewModel(
settingsManager, settingsManager,
notificationService, notificationService,

15
StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs

@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Linq; using System.Linq;
using System.Threading.Tasks;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
@ -30,6 +31,7 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly INavigationService<MainWindowViewModel> navigationService; private readonly INavigationService<MainWindowViewModel> navigationService;
private readonly IPackageFactory packageFactory; private readonly IPackageFactory packageFactory;
private readonly RunningPackageService runningPackageService;
[ObservableProperty] [ObservableProperty]
private string title = "Hello"; private string title = "Hello";
@ -58,12 +60,14 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa
public InferenceConnectionHelpViewModel( public InferenceConnectionHelpViewModel(
ISettingsManager settingsManager, ISettingsManager settingsManager,
INavigationService<MainWindowViewModel> navigationService, INavigationService<MainWindowViewModel> navigationService,
IPackageFactory packageFactory IPackageFactory packageFactory,
RunningPackageService runningPackageService
) )
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.navigationService = navigationService; this.navigationService = navigationService;
this.packageFactory = packageFactory; this.packageFactory = packageFactory;
this.runningPackageService = runningPackageService;
// Get comfy type installed packages // Get comfy type installed packages
var comfyPackages = this.settingsManager.Settings.InstalledPackages.Where( var comfyPackages = this.settingsManager.Settings.InstalledPackages.Where(
@ -122,14 +126,11 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa
/// Request launch of the selected package /// Request launch of the selected package
/// </summary> /// </summary>
[RelayCommand] [RelayCommand]
private void LaunchSelectedPackage() private async Task LaunchSelectedPackage()
{ {
if (SelectedPackage?.Id is { } id) if (SelectedPackage is not null)
{ {
Dispatcher.UIThread.Post(() => await runningPackageService.StartPackage(SelectedPackage);
{
EventManager.Instance.OnPackageLaunchRequested(id);
});
} }
} }

84
StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs

@ -1,6 +1,8 @@
using System; using System;
using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.Collections.Specialized;
using System.Linq; using System.Linq;
using System.Reactive.Linq; using System.Reactive.Linq;
using System.Text.Json; using System.Text.Json;
@ -51,6 +53,8 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
private readonly ServiceManager<ViewModelBase> vmFactory; private readonly ServiceManager<ViewModelBase> vmFactory;
private readonly IModelIndexService modelIndexService; private readonly IModelIndexService modelIndexService;
private readonly ILiteDbContext liteDbContext; private readonly ILiteDbContext liteDbContext;
private readonly RunningPackageService runningPackageService;
private Guid? selectedPackageId;
public override string Title => "Inference"; public override string Title => "Inference";
public override IconSource IconSource => public override IconSource IconSource =>
@ -86,6 +90,8 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
public bool IsComfyRunning => RunningPackage?.BasePackage is ComfyUI; public bool IsComfyRunning => RunningPackage?.BasePackage is ComfyUI;
private IDisposable? onStartupComplete;
public InferenceViewModel( public InferenceViewModel(
ServiceManager<ViewModelBase> vmFactory, ServiceManager<ViewModelBase> vmFactory,
INotificationService notificationService, INotificationService notificationService,
@ -93,6 +99,7 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
ISettingsManager settingsManager, ISettingsManager settingsManager,
IModelIndexService modelIndexService, IModelIndexService modelIndexService,
ILiteDbContext liteDbContext, ILiteDbContext liteDbContext,
RunningPackageService runningPackageService,
SharedState sharedState SharedState sharedState
) )
{ {
@ -101,12 +108,13 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.modelIndexService = modelIndexService; this.modelIndexService = modelIndexService;
this.liteDbContext = liteDbContext; this.liteDbContext = liteDbContext;
this.runningPackageService = runningPackageService;
ClientManager = inferenceClientManager; ClientManager = inferenceClientManager;
SharedState = sharedState; SharedState = sharedState;
// Keep RunningPackage updated with the current package pair // Keep RunningPackage updated with the current package pair
EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged; runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged;
// "Send to Inference" // "Send to Inference"
EventManager.Instance.InferenceTextToImageRequested += OnInferenceTextToImageRequested; EventManager.Instance.InferenceTextToImageRequested += OnInferenceTextToImageRequested;
@ -118,24 +126,62 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService); MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService);
} }
private void DisconnectFromComfy()
{
RunningPackage = null;
// Cancel any pending connection
if (ConnectCancelCommand.CanExecute(null))
{
ConnectCancelCommand.Execute(null);
}
onStartupComplete?.Dispose();
onStartupComplete = null;
IsWaitingForConnection = false;
// Disconnect
Logger.Trace("On package close - disconnecting");
DisconnectCommand.Execute(null);
}
/// <summary> /// <summary>
/// Updates the RunningPackage property when the running package changes. /// Updates the RunningPackage property when the running package changes.
/// Also starts a connection to the backend if a new ComfyUI package is running. /// Also starts a connection to the backend if a new ComfyUI package is running.
/// And disconnects if the package is closed. /// And disconnects if the package is closed.
/// </summary> /// </summary>
private void OnRunningPackageStatusChanged(object? sender, RunningPackageStatusChangedEventArgs e) private void RunningPackagesOnCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)
{
if (
e.NewItems?.OfType<KeyValuePair<Guid, RunningPackageViewModel>>().Select(x => x.Value)
is not { } newItems
)
{
if (RunningPackage != null)
{ {
RunningPackage = e.CurrentPackagePair; DisconnectFromComfy();
}
return;
}
IDisposable? onStartupComplete = null; var comfyViewModel = newItems.FirstOrDefault(
vm =>
vm.RunningPackage.InstalledPackage.Id == selectedPackageId
|| vm.RunningPackage.BasePackage is ComfyUI
);
Dispatcher.UIThread.Post(() => if (comfyViewModel is null && RunningPackage?.BasePackage is ComfyUI)
{ {
if (e.CurrentPackagePair?.BasePackage is ComfyUI package) DisconnectFromComfy();
}
else if (comfyViewModel != null && RunningPackage == null)
{ {
IsWaitingForConnection = true; IsWaitingForConnection = true;
RunningPackage = comfyViewModel.RunningPackage;
onStartupComplete = Observable onStartupComplete = Observable
.FromEventPattern<string>(package, nameof(package.StartupComplete)) .FromEventPattern<string>(
comfyViewModel.RunningPackage.BasePackage,
nameof(comfyViewModel.RunningPackage.BasePackage.StartupComplete)
)
.Take(1) .Take(1)
.Subscribe(_ => .Subscribe(_ =>
{ {
@ -146,26 +192,11 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
Logger.Trace("On package launch - starting connection"); Logger.Trace("On package launch - starting connection");
ConnectCommand.Execute(null); ConnectCommand.Execute(null);
} }
IsWaitingForConnection = false; IsWaitingForConnection = false;
}); });
}); });
} }
else
{
// Cancel any pending connection
if (ConnectCancelCommand.CanExecute(null))
{
ConnectCancelCommand.Execute(null);
}
onStartupComplete?.Dispose();
onStartupComplete = null;
IsWaitingForConnection = false;
// Disconnect
Logger.Trace("On package close - disconnecting");
DisconnectCommand.Execute(null);
}
});
} }
public override void OnLoaded() public override void OnLoaded()
@ -390,7 +421,12 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
private async Task ShowConnectionHelp() private async Task ShowConnectionHelp()
{ {
var vm = vmFactory.Get<InferenceConnectionHelpViewModel>(); var vm = vmFactory.Get<InferenceConnectionHelpViewModel>();
await vm.CreateDialog().ShowAsync(); var result = await vm.CreateDialog().ShowAsync();
if (result != ContentDialogResult.Primary)
return;
selectedPackageId = vm.SelectedPackage?.Id;
} }
/// <summary> /// <summary>

65
StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs

@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Specialized;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -36,15 +37,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.PackageManager;
[ManagedService] [ManagedService]
[Transient] [Transient]
public partial class PackageCardViewModel( public partial class PackageCardViewModel : ProgressViewModel
ILogger<PackageCardViewModel> logger,
IPackageFactory packageFactory,
INotificationService notificationService,
ISettingsManager settingsManager,
INavigationService<NewPackageManagerViewModel> navigationService,
ServiceManager<ViewModelBase> vmFactory,
RunningPackageService runningPackageService
) : ProgressViewModel
{ {
private string webUiUrl = string.Empty; private string webUiUrl = string.Empty;
@ -93,6 +86,59 @@ public partial class PackageCardViewModel(
[ObservableProperty] [ObservableProperty]
private bool showWebUiButton; private bool showWebUiButton;
private readonly ILogger<PackageCardViewModel> logger;
private readonly IPackageFactory packageFactory;
private readonly INotificationService notificationService;
private readonly ISettingsManager settingsManager;
private readonly INavigationService<NewPackageManagerViewModel> navigationService;
private readonly ServiceManager<ViewModelBase> vmFactory;
private readonly RunningPackageService runningPackageService;
/// <inheritdoc/>
public PackageCardViewModel(
ILogger<PackageCardViewModel> logger,
IPackageFactory packageFactory,
INotificationService notificationService,
ISettingsManager settingsManager,
INavigationService<NewPackageManagerViewModel> navigationService,
ServiceManager<ViewModelBase> vmFactory,
RunningPackageService runningPackageService
)
{
this.logger = logger;
this.packageFactory = packageFactory;
this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.navigationService = navigationService;
this.vmFactory = vmFactory;
this.runningPackageService = runningPackageService;
runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged;
}
private void RunningPackagesOnCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)
{
if (
e.NewItems?.OfType<KeyValuePair<Guid, RunningPackageViewModel>>().Select(x => x.Value)
is not { } newItems
)
return;
var runningViewModel = newItems.FirstOrDefault(
x => x.RunningPackage.InstalledPackage.Id == Package?.Id
);
if (runningViewModel is not null)
{
IsRunning = true;
runningViewModel.RunningPackage.BasePackage.Exited += BasePackageOnExited;
runningViewModel.RunningPackage.BasePackage.StartupComplete += RunningPackageOnStartupComplete;
}
else if (runningViewModel is null && IsRunning)
{
IsRunning = false;
}
}
partial void OnPackageChanged(InstalledPackage? value) partial void OnPackageChanged(InstalledPackage? value)
{ {
if (string.IsNullOrWhiteSpace(value?.PackageName)) if (string.IsNullOrWhiteSpace(value?.PackageName))
@ -232,7 +278,6 @@ public partial class PackageCardViewModel(
private void BasePackageOnExited(object? sender, int exitCode) private void BasePackageOnExited(object? sender, int exitCode)
{ {
EventManager.Instance.OnRunningPackageStatusChanged(null);
Dispatcher Dispatcher
.UIThread.InvokeAsync(async () => .UIThread.InvokeAsync(async () =>
{ {

19
StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs

@ -1,7 +1,5 @@
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Threading;
using AvaloniaEdit;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls; using FluentAvalonia.UI.Controls;
@ -19,7 +17,7 @@ using SymbolIconSource = FluentIcons.Avalonia.Fluent.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels; namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(ConsoleOutputPage))] [View(typeof(ConsoleOutputPage))]
public partial class RunningPackageViewModel : PageViewModelBase public partial class RunningPackageViewModel : PageViewModelBase, IDisposable, IAsyncDisposable
{ {
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private string webUiUrl = string.Empty; private string webUiUrl = string.Empty;
@ -51,7 +49,7 @@ public partial class RunningPackageViewModel : PageViewModelBase
Console = console; Console = console;
Console.Document.LineCountChanged += DocumentOnLineCountChanged; Console.Document.LineCountChanged += DocumentOnLineCountChanged;
RunningPackage.BasePackage.StartupComplete += BasePackageOnStartupComplete; RunningPackage.BasePackage.StartupComplete += BasePackageOnStartupComplete;
runningPackage.BasePackage.Exited += BasePackageOnExited; RunningPackage.BasePackage.Exited += BasePackageOnExited;
settingsManager.RelayPropertyFor( settingsManager.RelayPropertyFor(
this, this,
@ -65,6 +63,9 @@ public partial class RunningPackageViewModel : PageViewModelBase
{ {
IsRunning = false; IsRunning = false;
ShowWebUiButton = false; ShowWebUiButton = false;
Console.Document.LineCountChanged -= DocumentOnLineCountChanged;
RunningPackage.BasePackage.StartupComplete -= BasePackageOnStartupComplete;
RunningPackage.BasePackage.Exited -= BasePackageOnExited;
} }
private void BasePackageOnStartupComplete(object? sender, string url) private void BasePackageOnStartupComplete(object? sender, string url)
@ -123,4 +124,14 @@ public partial class RunningPackageViewModel : PageViewModelBase
} }
} }
} }
public void Dispose()
{
Console.Dispose();
}
public async ValueTask DisposeAsync()
{
await Console.DisposeAsync();
}
} }

45
StabilityMatrix.Core/Helper/Factory/PackageFactory.cs

@ -2,6 +2,7 @@
using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Helper.Factory; namespace StabilityMatrix.Core.Helper.Factory;
@ -13,6 +14,7 @@ public class PackageFactory : IPackageFactory
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly IDownloadService downloadService; private readonly IDownloadService downloadService;
private readonly IPrerequisiteHelper prerequisiteHelper; private readonly IPrerequisiteHelper prerequisiteHelper;
private readonly IPyRunner pyRunner;
/// <summary> /// <summary>
/// Mapping of package.Name to package /// Mapping of package.Name to package
@ -24,13 +26,15 @@ public class PackageFactory : IPackageFactory
IGithubApiCache githubApiCache, IGithubApiCache githubApiCache,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IDownloadService downloadService, IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner
) )
{ {
this.githubApiCache = githubApiCache; this.githubApiCache = githubApiCache;
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.downloadService = downloadService; this.downloadService = downloadService;
this.prerequisiteHelper = prerequisiteHelper; this.prerequisiteHelper = prerequisiteHelper;
this.pyRunner = pyRunner;
this.basePackages = basePackages.ToDictionary(x => x.Name); this.basePackages = basePackages.ToDictionary(x => x.Name);
} }
@ -44,7 +48,44 @@ public class PackageFactory : IPackageFactory
=> new A3WebUI(githubApiCache, settingsManager, downloadService, prerequisiteHelper), => new A3WebUI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"Fooocus-ControlNet-SDXL" "Fooocus-ControlNet-SDXL"
=> new FocusControlNet(githubApiCache, settingsManager, downloadService, prerequisiteHelper), => new FocusControlNet(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
_ => throw new ArgumentOutOfRangeException() "Fooocus-MRE"
=> new FooocusMre(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"InvokeAI" => new InvokeAI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"kohya_ss"
=> new KohyaSs(
githubApiCache,
settingsManager,
downloadService,
prerequisiteHelper,
pyRunner
),
"OneTrainer"
=> new OneTrainer(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"RuinedFooocus"
=> new RuinedFooocus(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"stable-diffusion-webui-forge"
=> new SDWebForge(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"stable-diffusion-webui-directml"
=> new StableDiffusionDirectMl(
githubApiCache,
settingsManager,
downloadService,
prerequisiteHelper
),
"stable-diffusion-webui-ux"
=> new StableDiffusionUx(
githubApiCache,
settingsManager,
downloadService,
prerequisiteHelper
),
"StableSwarmUI"
=> new StableSwarm(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"automatic"
=> new VladAutomatic(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"voltaML-fast-stable-diffusion"
=> new VoltaML(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
_ => throw new ArgumentOutOfRangeException(nameof(installedPackage))
}; };
} }

Loading…
Cancel
Save