Browse Source

fix inference launch connection stuff to work with multi-package

pull/629/head
JT 9 months ago
parent
commit
45f48ce36f
  1. 4
      StabilityMatrix.Avalonia/Services/RunningPackageService.cs
  2. 15
      StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs
  3. 84
      StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
  4. 65
      StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
  5. 19
      StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs
  6. 45
      StabilityMatrix.Core/Helper/Factory/PackageFactory.cs

4
StabilityMatrix.Avalonia/Services/RunningPackageService.cs

@ -1,6 +1,5 @@
using System;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Controls.Notifications;
@ -10,7 +9,6 @@ using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.ViewModels;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
@ -114,8 +112,6 @@ public partial class RunningPackageService(
await basePackage.RunPackage(packagePath, command, userArgsString, o => console.Post(o));
var runningPackage = new PackagePair(installedPackage, basePackage);
EventManager.Instance.OnRunningPackageStatusChanged(runningPackage);
var viewModel = new RunningPackageViewModel(
settingsManager,
notificationService,

15
StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs

@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
@ -30,6 +31,7 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa
private readonly ISettingsManager settingsManager;
private readonly INavigationService<MainWindowViewModel> navigationService;
private readonly IPackageFactory packageFactory;
private readonly RunningPackageService runningPackageService;
[ObservableProperty]
private string title = "Hello";
@ -58,12 +60,14 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa
public InferenceConnectionHelpViewModel(
ISettingsManager settingsManager,
INavigationService<MainWindowViewModel> navigationService,
IPackageFactory packageFactory
IPackageFactory packageFactory,
RunningPackageService runningPackageService
)
{
this.settingsManager = settingsManager;
this.navigationService = navigationService;
this.packageFactory = packageFactory;
this.runningPackageService = runningPackageService;
// Get comfy type installed packages
var comfyPackages = this.settingsManager.Settings.InstalledPackages.Where(
@ -122,14 +126,11 @@ public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBa
/// Request launch of the selected package
/// </summary>
[RelayCommand]
private void LaunchSelectedPackage()
private async Task LaunchSelectedPackage()
{
if (SelectedPackage?.Id is { } id)
if (SelectedPackage is not null)
{
Dispatcher.UIThread.Post(() =>
{
EventManager.Instance.OnPackageLaunchRequested(id);
});
await runningPackageService.StartPackage(SelectedPackage);
}
}

84
StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs

@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Collections.Specialized;
using System.Linq;
using System.Reactive.Linq;
using System.Text.Json;
@ -51,6 +53,8 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
private readonly ServiceManager<ViewModelBase> vmFactory;
private readonly IModelIndexService modelIndexService;
private readonly ILiteDbContext liteDbContext;
private readonly RunningPackageService runningPackageService;
private Guid? selectedPackageId;
public override string Title => "Inference";
public override IconSource IconSource =>
@ -86,6 +90,8 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
public bool IsComfyRunning => RunningPackage?.BasePackage is ComfyUI;
private IDisposable? onStartupComplete;
public InferenceViewModel(
ServiceManager<ViewModelBase> vmFactory,
INotificationService notificationService,
@ -93,6 +99,7 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
ISettingsManager settingsManager,
IModelIndexService modelIndexService,
ILiteDbContext liteDbContext,
RunningPackageService runningPackageService,
SharedState sharedState
)
{
@ -101,12 +108,13 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
this.settingsManager = settingsManager;
this.modelIndexService = modelIndexService;
this.liteDbContext = liteDbContext;
this.runningPackageService = runningPackageService;
ClientManager = inferenceClientManager;
SharedState = sharedState;
// Keep RunningPackage updated with the current package pair
EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged;
runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged;
// "Send to Inference"
EventManager.Instance.InferenceTextToImageRequested += OnInferenceTextToImageRequested;
@ -118,24 +126,62 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService);
}
private void DisconnectFromComfy()
{
RunningPackage = null;
// Cancel any pending connection
if (ConnectCancelCommand.CanExecute(null))
{
ConnectCancelCommand.Execute(null);
}
onStartupComplete?.Dispose();
onStartupComplete = null;
IsWaitingForConnection = false;
// Disconnect
Logger.Trace("On package close - disconnecting");
DisconnectCommand.Execute(null);
}
/// <summary>
/// Updates the RunningPackage property when the running package changes.
/// Also starts a connection to the backend if a new ComfyUI package is running.
/// And disconnects if the package is closed.
/// </summary>
private void OnRunningPackageStatusChanged(object? sender, RunningPackageStatusChangedEventArgs e)
private void RunningPackagesOnCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)
{
if (
e.NewItems?.OfType<KeyValuePair<Guid, RunningPackageViewModel>>().Select(x => x.Value)
is not { } newItems
)
{
if (RunningPackage != null)
{
RunningPackage = e.CurrentPackagePair;
DisconnectFromComfy();
}
return;
}
IDisposable? onStartupComplete = null;
var comfyViewModel = newItems.FirstOrDefault(
vm =>
vm.RunningPackage.InstalledPackage.Id == selectedPackageId
|| vm.RunningPackage.BasePackage is ComfyUI
);
Dispatcher.UIThread.Post(() =>
if (comfyViewModel is null && RunningPackage?.BasePackage is ComfyUI)
{
if (e.CurrentPackagePair?.BasePackage is ComfyUI package)
DisconnectFromComfy();
}
else if (comfyViewModel != null && RunningPackage == null)
{
IsWaitingForConnection = true;
RunningPackage = comfyViewModel.RunningPackage;
onStartupComplete = Observable
.FromEventPattern<string>(package, nameof(package.StartupComplete))
.FromEventPattern<string>(
comfyViewModel.RunningPackage.BasePackage,
nameof(comfyViewModel.RunningPackage.BasePackage.StartupComplete)
)
.Take(1)
.Subscribe(_ =>
{
@ -146,26 +192,11 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
Logger.Trace("On package launch - starting connection");
ConnectCommand.Execute(null);
}
IsWaitingForConnection = false;
});
});
}
else
{
// Cancel any pending connection
if (ConnectCancelCommand.CanExecute(null))
{
ConnectCancelCommand.Execute(null);
}
onStartupComplete?.Dispose();
onStartupComplete = null;
IsWaitingForConnection = false;
// Disconnect
Logger.Trace("On package close - disconnecting");
DisconnectCommand.Execute(null);
}
});
}
public override void OnLoaded()
@ -390,7 +421,12 @@ public partial class InferenceViewModel : PageViewModelBase, IAsyncDisposable
private async Task ShowConnectionHelp()
{
var vm = vmFactory.Get<InferenceConnectionHelpViewModel>();
await vm.CreateDialog().ShowAsync();
var result = await vm.CreateDialog().ShowAsync();
if (result != ContentDialogResult.Primary)
return;
selectedPackageId = vm.SelectedPackage?.Id;
}
/// <summary>

65
StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs

@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@ -36,15 +37,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.PackageManager;
[ManagedService]
[Transient]
public partial class PackageCardViewModel(
ILogger<PackageCardViewModel> logger,
IPackageFactory packageFactory,
INotificationService notificationService,
ISettingsManager settingsManager,
INavigationService<NewPackageManagerViewModel> navigationService,
ServiceManager<ViewModelBase> vmFactory,
RunningPackageService runningPackageService
) : ProgressViewModel
public partial class PackageCardViewModel : ProgressViewModel
{
private string webUiUrl = string.Empty;
@ -93,6 +86,59 @@ public partial class PackageCardViewModel(
[ObservableProperty]
private bool showWebUiButton;
private readonly ILogger<PackageCardViewModel> logger;
private readonly IPackageFactory packageFactory;
private readonly INotificationService notificationService;
private readonly ISettingsManager settingsManager;
private readonly INavigationService<NewPackageManagerViewModel> navigationService;
private readonly ServiceManager<ViewModelBase> vmFactory;
private readonly RunningPackageService runningPackageService;
/// <inheritdoc/>
public PackageCardViewModel(
ILogger<PackageCardViewModel> logger,
IPackageFactory packageFactory,
INotificationService notificationService,
ISettingsManager settingsManager,
INavigationService<NewPackageManagerViewModel> navigationService,
ServiceManager<ViewModelBase> vmFactory,
RunningPackageService runningPackageService
)
{
this.logger = logger;
this.packageFactory = packageFactory;
this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.navigationService = navigationService;
this.vmFactory = vmFactory;
this.runningPackageService = runningPackageService;
runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged;
}
private void RunningPackagesOnCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)
{
if (
e.NewItems?.OfType<KeyValuePair<Guid, RunningPackageViewModel>>().Select(x => x.Value)
is not { } newItems
)
return;
var runningViewModel = newItems.FirstOrDefault(
x => x.RunningPackage.InstalledPackage.Id == Package?.Id
);
if (runningViewModel is not null)
{
IsRunning = true;
runningViewModel.RunningPackage.BasePackage.Exited += BasePackageOnExited;
runningViewModel.RunningPackage.BasePackage.StartupComplete += RunningPackageOnStartupComplete;
}
else if (runningViewModel is null && IsRunning)
{
IsRunning = false;
}
}
partial void OnPackageChanged(InstalledPackage? value)
{
if (string.IsNullOrWhiteSpace(value?.PackageName))
@ -232,7 +278,6 @@ public partial class PackageCardViewModel(
private void BasePackageOnExited(object? sender, int exitCode)
{
EventManager.Instance.OnRunningPackageStatusChanged(null);
Dispatcher
.UIThread.InvokeAsync(async () =>
{

19
StabilityMatrix.Avalonia/ViewModels/RunningPackageViewModel.cs

@ -1,7 +1,5 @@
using System;
using System.Threading.Tasks;
using Avalonia.Threading;
using AvaloniaEdit;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
@ -19,7 +17,7 @@ using SymbolIconSource = FluentIcons.Avalonia.Fluent.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(ConsoleOutputPage))]
public partial class RunningPackageViewModel : PageViewModelBase
public partial class RunningPackageViewModel : PageViewModelBase, IDisposable, IAsyncDisposable
{
private readonly INotificationService notificationService;
private string webUiUrl = string.Empty;
@ -51,7 +49,7 @@ public partial class RunningPackageViewModel : PageViewModelBase
Console = console;
Console.Document.LineCountChanged += DocumentOnLineCountChanged;
RunningPackage.BasePackage.StartupComplete += BasePackageOnStartupComplete;
runningPackage.BasePackage.Exited += BasePackageOnExited;
RunningPackage.BasePackage.Exited += BasePackageOnExited;
settingsManager.RelayPropertyFor(
this,
@ -65,6 +63,9 @@ public partial class RunningPackageViewModel : PageViewModelBase
{
IsRunning = false;
ShowWebUiButton = false;
Console.Document.LineCountChanged -= DocumentOnLineCountChanged;
RunningPackage.BasePackage.StartupComplete -= BasePackageOnStartupComplete;
RunningPackage.BasePackage.Exited -= BasePackageOnExited;
}
private void BasePackageOnStartupComplete(object? sender, string url)
@ -123,4 +124,14 @@ public partial class RunningPackageViewModel : PageViewModelBase
}
}
}
public void Dispose()
{
Console.Dispose();
}
public async ValueTask DisposeAsync()
{
await Console.DisposeAsync();
}
}

45
StabilityMatrix.Core/Helper/Factory/PackageFactory.cs

@ -2,6 +2,7 @@
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Helper.Factory;
@ -13,6 +14,7 @@ public class PackageFactory : IPackageFactory
private readonly ISettingsManager settingsManager;
private readonly IDownloadService downloadService;
private readonly IPrerequisiteHelper prerequisiteHelper;
private readonly IPyRunner pyRunner;
/// <summary>
/// Mapping of package.Name to package
@ -24,13 +26,15 @@ public class PackageFactory : IPackageFactory
IGithubApiCache githubApiCache,
ISettingsManager settingsManager,
IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper
IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner
)
{
this.githubApiCache = githubApiCache;
this.settingsManager = settingsManager;
this.downloadService = downloadService;
this.prerequisiteHelper = prerequisiteHelper;
this.pyRunner = pyRunner;
this.basePackages = basePackages.ToDictionary(x => x.Name);
}
@ -44,7 +48,44 @@ public class PackageFactory : IPackageFactory
=> new A3WebUI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"Fooocus-ControlNet-SDXL"
=> new FocusControlNet(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
_ => throw new ArgumentOutOfRangeException()
"Fooocus-MRE"
=> new FooocusMre(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"InvokeAI" => new InvokeAI(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"kohya_ss"
=> new KohyaSs(
githubApiCache,
settingsManager,
downloadService,
prerequisiteHelper,
pyRunner
),
"OneTrainer"
=> new OneTrainer(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"RuinedFooocus"
=> new RuinedFooocus(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"stable-diffusion-webui-forge"
=> new SDWebForge(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"stable-diffusion-webui-directml"
=> new StableDiffusionDirectMl(
githubApiCache,
settingsManager,
downloadService,
prerequisiteHelper
),
"stable-diffusion-webui-ux"
=> new StableDiffusionUx(
githubApiCache,
settingsManager,
downloadService,
prerequisiteHelper
),
"StableSwarmUI"
=> new StableSwarm(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"automatic"
=> new VladAutomatic(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
"voltaML-fast-stable-diffusion"
=> new VoltaML(githubApiCache, settingsManager, downloadService, prerequisiteHelper),
_ => throw new ArgumentOutOfRangeException(nameof(installedPackage))
};
}

Loading…
Cancel
Save