You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
654 lines
20 KiB
654 lines
20 KiB
using System; |
|
using System.Collections.Immutable; |
|
using System.Collections.ObjectModel; |
|
using System.Linq; |
|
using System.Reactive.Linq; |
|
using System.Text.Json; |
|
using System.Threading; |
|
using System.Threading.Tasks; |
|
using AsyncAwaitBestPractices; |
|
using Avalonia.Controls; |
|
using Avalonia.Controls.Notifications; |
|
using Avalonia.Controls.Shapes; |
|
using Avalonia.Platform.Storage; |
|
using Avalonia.Threading; |
|
using CommunityToolkit.Mvvm.ComponentModel; |
|
using CommunityToolkit.Mvvm.Input; |
|
using FluentAvalonia.UI.Controls; |
|
using NLog; |
|
using StabilityMatrix.Avalonia.Extensions; |
|
using StabilityMatrix.Avalonia.Models; |
|
using StabilityMatrix.Avalonia.Services; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Avalonia.ViewModels.Dialogs; |
|
using StabilityMatrix.Avalonia.ViewModels.Inference; |
|
using StabilityMatrix.Avalonia.Views; |
|
using StabilityMatrix.Core.Api; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Database; |
|
using StabilityMatrix.Core.Extensions; |
|
using StabilityMatrix.Core.Helper; |
|
using StabilityMatrix.Core.Models; |
|
using StabilityMatrix.Core.Models.Database; |
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
using StabilityMatrix.Core.Models.Packages; |
|
using StabilityMatrix.Core.Models.Progress; |
|
using StabilityMatrix.Core.Services; |
|
using InferenceTabViewModelBase = StabilityMatrix.Avalonia.ViewModels.Base.InferenceTabViewModelBase; |
|
using Path = System.IO.Path; |
|
using Symbol = FluentIcons.Common.Symbol; |
|
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels; |
|
|
|
[Preload]
[View(typeof(InferencePage))]
public partial class InferenceViewModel : PageViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Serializer options used when writing project (.smproj) files.
    /// Reused across saves instead of being re-created for every call.
    /// </summary>
    private static readonly JsonSerializerOptions ProjectJsonOptions = new() { WriteIndented = true };

    /// <summary>
    /// File type filter for StabilityMatrix project files, shared by the open and save pickers.
    /// </summary>
    private static readonly FilePickerFileType ProjectFileType =
        new("StabilityMatrix Project")
        {
            Patterns = new[] { "*.smproj" },
            MimeTypes = new[] { "application/json" },
        };

    private readonly INotificationService notificationService;
    private readonly ISettingsManager settingsManager;
    private readonly ServiceManager<ViewModelBase> vmFactory;
    private readonly IModelIndexService modelIndexService;
    private readonly ILiteDbContext liteDbContext;

    private bool isFirstLoadComplete;

    /// <summary>
    /// Pending subscription that waits for the running ComfyUI package to raise StartupComplete.
    /// Stored as a field (not a local) so a later package status change can cancel it.
    /// </summary>
    private IDisposable? onStartupComplete;

    public override string Title => "Inference";

    public override IconSource IconSource =>
        new SymbolIconSource { Symbol = Symbol.AppGeneric, IsFilled = true };

    public RefreshBadgeViewModel ConnectionBadge { get; } =
        new()
        {
            State = ProgressState.Failed,
            FailToolTipText = "Not connected",
            FailIcon = FluentAvalonia.UI.Controls.Symbol.Refresh,
            SuccessToolTipText = "Connected",
        };

    public IInferenceClientManager ClientManager { get; }

    public SharedState SharedState { get; }

    public ObservableCollection<InferenceTabViewModelBase> Tabs { get; } = new();

    [ObservableProperty]
    private InferenceTabViewModelBase? selectedTab;

    [ObservableProperty]
    private int selectedTabIndex;

    [ObservableProperty]
    private bool isWaitingForConnection;

    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(IsComfyRunning))]
    private PackagePair? runningPackage;

    /// <summary>
    /// True when the currently running package is a ComfyUI package.
    /// </summary>
    public bool IsComfyRunning => RunningPackage?.BasePackage is ComfyUI;

    public InferenceViewModel(
        ServiceManager<ViewModelBase> vmFactory,
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ISettingsManager settingsManager,
        IModelIndexService modelIndexService,
        ILiteDbContext liteDbContext,
        SharedState sharedState
    )
    {
        this.vmFactory = vmFactory;
        this.notificationService = notificationService;
        this.settingsManager = settingsManager;
        this.modelIndexService = modelIndexService;
        this.liteDbContext = liteDbContext;

        ClientManager = inferenceClientManager;
        SharedState = sharedState;

        // Keep RunningPackage updated with the current package pair
        EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged;

        MenuSaveAsCommand.WithConditionalNotificationErrorHandler(notificationService);
        MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService);
    }

    /// <summary>
    /// Updates <see cref="RunningPackage"/> when the running package changes.
    /// If the new package is ComfyUI, waits for its StartupComplete event and then runs
    /// <see cref="ConnectCommand"/>. Otherwise cancels any pending connection, drops the
    /// startup subscription and runs <see cref="DisconnectCommand"/>.
    /// </summary>
    private void OnRunningPackageStatusChanged(
        object? sender,
        RunningPackageStatusChangedEventArgs e
    )
    {
        RunningPackage = e.CurrentPackagePair;

        Dispatcher.UIThread.Post(() =>
        {
            // A previous package may still have an outstanding startup subscription; it is
            // no longer relevant regardless of what the new status is.
            onStartupComplete?.Dispose();
            onStartupComplete = null;

            if (e.CurrentPackagePair?.BasePackage is ComfyUI package)
            {
                IsWaitingForConnection = true;

                onStartupComplete = Observable
                    .FromEventPattern<string>(package, nameof(package.StartupComplete))
                    .Take(1)
                    .Subscribe(_ =>
                    {
                        Dispatcher.UIThread.Post(() =>
                        {
                            if (ConnectCommand.CanExecute(null))
                            {
                                Logger.Trace("On package launch - starting connection");
                                ConnectCommand.Execute(null);
                            }

                            IsWaitingForConnection = false;
                        });
                    });
            }
            else
            {
                // Cancel any pending connection
                if (ConnectCancelCommand.CanExecute(null))
                {
                    ConnectCancelCommand.Execute(null);
                }

                IsWaitingForConnection = false;

                // Disconnect
                Logger.Trace("On package close - disconnecting");
                DisconnectCommand.Execute(null);
            }
        });
    }

    /// <summary>
    /// On the first (non-design-mode) load, reopens projects recorded as open in the database.
    /// Ensures at least one tab exists, then starts a background model index refresh.
    /// </summary>
    public override async Task OnLoadedAsync()
    {
        await base.OnLoadedAsync();

        if (!Design.IsDesignMode && !isFirstLoadComplete)
        {
            isFirstLoadComplete = true;

            await RestoreOpenProjectsAsync();
        }

        if (Tabs.Count == 0)
        {
            AddTab(InferenceProjectType.TextToImage);
        }

        // Start a model index update
        modelIndexService.BackgroundRefreshIndex();
    }

    /// <summary>
    /// Reopens every project entry marked <c>IsOpen</c>, in <c>CurrentTabIndex</c> order.
    /// Entries whose file no longer exists are deleted. Entries that fail to load are marked
    /// closed and reported to the user. Afterwards the entry marked <c>IsSelected</c> (if any)
    /// becomes the selected tab.
    /// </summary>
    private async Task RestoreOpenProjectsAsync()
    {
        var openProjects = await liteDbContext.InferenceProjects.FindAsync(p => p.IsOpen);

        if (openProjects is null)
        {
            return;
        }

        // Materialized because the list is enumerated again to restore the selection
        var projects = openProjects.OrderBy(p => p.CurrentTabIndex).ToList();

        foreach (var project in projects)
        {
            var file = new FilePath(project.FilePath);

            if (!file.Exists)
            {
                // Remove from database
                await liteDbContext.InferenceProjects.DeleteAsync(project.Id);
                continue;
            }

            try
            {
                await AddTabFromFile(file);
            }
            catch (Exception e)
            {
                Logger.Warn(e, "Failed to open project file {FilePath}", project.FilePath);

                notificationService.Show(
                    "Failed to open project file",
                    $"[{e.GetType().Name}] {e.Message}",
                    NotificationType.Error
                );

                // Set not open
                await liteDbContext.InferenceProjects.UpdateAsync(
                    project with
                    {
                        IsOpen = false,
                        IsSelected = false,
                        CurrentTabIndex = -1
                    }
                );
            }
        }

        // AddTabFromFile selects each tab as it is added; re-select the one that was
        // selected when the projects were last saved, if it was restored.
        var selectedPath = projects.FirstOrDefault(p => p.IsSelected)?.FilePath;
        if (
            selectedPath is not null
            && Tabs.FirstOrDefault(t => t.ProjectFile?.ToString() == selectedPath) is { } tabToSelect
            && tabToSelect != SelectedTab
        )
        {
            SelectedTab = tabToSelect;
            await SyncTabStatesWithDatabase();
        }
    }

    /// <summary>
    /// Updates the database with the state of every tab that is backed by a project file.
    /// </summary>
    private async Task SyncTabStatesWithDatabase()
    {
        foreach (var (i, tab) in Tabs.ToImmutableArray().Enumerate())
        {
            await UpsertTabEntryAsync(tab, i);
        }
    }

    /// <summary>
    /// Updates the database with the state of a single tab.
    /// A tab that is no longer in <see cref="Tabs"/> (e.g. just closed) is recorded as not open.
    /// </summary>
    private async Task SyncTabStateWithDatabase(InferenceTabViewModelBase tab)
    {
        await UpsertTabEntryAsync(tab, Tabs.IndexOf(tab));
    }

    /// <summary>
    /// Creates or updates the project entry for <paramref name="tab"/>.
    /// Does nothing if the tab has no project file.
    /// </summary>
    /// <param name="tab">Tab whose project entry should be written.</param>
    /// <param name="tabIndex">Position of the tab in <see cref="Tabs"/>, or -1 if it is not open.</param>
    private async Task UpsertTabEntryAsync(InferenceTabViewModelBase tab, int tabIndex)
    {
        if (tab.ProjectFile is not { } projectFile)
        {
            return;
        }

        // Evaluated up front so the query expression only captures a plain string
        var projectPath = projectFile.ToString();

        // Snapshot tab state before awaiting; a closed tab is disposed right after this is started
        var isOpen = tabIndex >= 0;
        var isSelected = isOpen && tab == SelectedTab;
        var tabTitle = tab.TabTitle;

        var entry = await liteDbContext.InferenceProjects.FindOneAsync(
            p => p.FilePath == projectPath
        );

        // Create if not found
        entry ??= new InferenceProjectEntry { Id = Guid.NewGuid(), FilePath = projectPath };

        // IsOpen drives which projects are restored on load; IsSelected drives which one is focused
        entry.IsOpen = isOpen;
        entry.IsSelected = isSelected;
        entry.CurrentTabIndex = tabIndex;

        Logger.Trace("Updated database entry for tab '{Title}': {@Entry}", tabTitle, entry);
        await liteDbContext.InferenceProjects.UpsertAsync(entry);
    }

    /// <summary>
    /// When the + button on the tab control is clicked, add a new tab of the given project type
    /// and select it. Does nothing if the type has no associated view model type.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The factory did not produce an <see cref="InferenceTabViewModelBase"/>.
    /// </exception>
    [RelayCommand]
    public void AddTab(InferenceProjectType type)
    {
        if (type.ToViewModelType() is not { } vmType)
        {
            return;
        }

        if (vmFactory.Get(vmType) is not InferenceTabViewModelBase tab)
        {
            throw new InvalidOperationException($"Could not create view model of type {vmType}");
        }

        Tabs.Add(tab);

        // Set as new selected tab
        SelectedTabIndex = Tabs.Count - 1;

        // Update the database with the current tab
        SyncTabStateWithDatabase(tab).SafeFireAndForget();
    }

    /// <summary>
    /// When the close button on the tab is clicked, remove the tab.
    /// Selects the next tab (or the previous one if it was last), records the tab as closed
    /// in the database, and disposes the tab's view model.
    /// </summary>
    public void OnTabCloseRequested(TabViewTabCloseRequestedEventArgs e)
    {
        if (e.Item is not InferenceTabViewModelBase vm)
        {
            Logger.Warn("Tab close requested for unknown item {@Item}", e.Item);
            return;
        }

        Logger.Trace("Closing tab {Title}", vm.TabTitle);

        lock (Tabs)
        {
            var index = Tabs.IndexOf(vm);

            // Guard against a stale close request for a tab that was already removed
            if (index < 0)
            {
                Logger.Warn("Tab close requested for tab {Title} that is not open", vm.TabTitle);
                return;
            }

            // Set the selected tab to the next tab if there is one, then previous
            if (index < Tabs.Count - 1)
            {
                SelectedTabIndex = index + 1;
            }
            else if (index > 0)
            {
                SelectedTabIndex = index - 1;
            }

            // Remove the tab
            Tabs.RemoveAt(index);
        }

        // Update the database with the current tab
        SyncTabStateWithDatabase(vm).SafeFireAndForget();

        // Dispose the view model
        vm.Dispose();
    }

    /// <summary>
    /// Show the connection help dialog.
    /// </summary>
    [RelayCommand]
    private async Task ShowConnectionHelp()
    {
        var vm = vmFactory.Get<InferenceConnectionHelpViewModel>();
        await vm.CreateDialog().ShowAsync();
    }

    /// <summary>
    /// Connect to the inference server using the running package.
    /// No-op if already connected or if no package is running (outside design mode).
    /// </summary>
    [RelayCommand(IncludeCancelCommand = true)]
    private async Task Connect(CancellationToken cancellationToken = default)
    {
        if (ClientManager.IsConnected)
        {
            return;
        }

        if (Design.IsDesignMode)
        {
            await ClientManager.ConnectAsync(cancellationToken);
            return;
        }

        if (RunningPackage is not null)
        {
            await notificationService.TryAsync(
                ClientManager.ConnectAsync(RunningPackage, cancellationToken),
                "Could not connect to backend"
            );
        }
    }

    /// <summary>
    /// Disconnect from the inference server. No-op if not connected.
    /// </summary>
    [RelayCommand]
    private async Task Disconnect()
    {
        if (!ClientManager.IsConnected)
        {
            return;
        }

        if (Design.IsDesignMode)
        {
            await ClientManager.CloseAsync();
            return;
        }

        await notificationService.TryAsync(
            ClientManager.CloseAsync(),
            "Could not disconnect from ComfyUI backend"
        );
    }

    /// <summary>
    /// Menu "Save As" command.
    /// Prompts for a destination, writes the selected tab's project document there and,
    /// when the destination has a local path, binds the tab to that file.
    /// </summary>
    [RelayCommand(FlowExceptionsToTaskScheduler = true)]
    private async Task MenuSaveAs()
    {
        if (SelectedTab is not { } currentTab)
        {
            Logger.Warn("MenuSaveAs: currentTab is null");
            return;
        }

        // Prompt for save file dialog
        var startDir = await GetProjectsStartFolderAsync();

        var result = await App.StorageProvider.SaveFilePickerAsync(
            new FilePickerSaveOptions
            {
                Title = "Save As",
                SuggestedFileName = "Untitled",
                FileTypeChoices = new[] { ProjectFileType },
                SuggestedStartLocation = startDir,
                DefaultExtension = ".smproj",
                ShowOverwritePrompt = true,
            }
        );

        if (result is null)
        {
            Logger.Trace("MenuSaveAs: user cancelled");
            return;
        }

        var document = InferenceProjectDocument.FromLoadable(currentTab);

        // Save to file
        try
        {
            await using var stream = await result.OpenWriteAsync();
            await WriteProjectDocumentAsync(stream, document);
        }
        catch (Exception e)
        {
            ShowSaveError(e);
            return;
        }

        // Update project file; a storage item without a local path cannot be tracked
        if (result.TryGetLocalPath() is { } localPath)
        {
            currentTab.ProjectFile = new FilePath(localPath);
            await SyncTabStatesWithDatabase();
        }
        else
        {
            Logger.Warn("MenuSaveAs: saved file {Name} has no local path; not tracking it", result.Name);
        }

        notificationService.Show(
            "Saved",
            $"Saved project to {result.Name}",
            NotificationType.Success
        );
    }

    /// <summary>
    /// Menu "Save Project" command.
    /// Overwrites the selected tab's project file, or falls back to "Save As" if it has none.
    /// </summary>
    [RelayCommand(FlowExceptionsToTaskScheduler = true)]
    private async Task MenuSave()
    {
        if (SelectedTab is not { } currentTab)
        {
            Logger.Info("MenuSaveProject: currentTab is null");
            return;
        }

        // If the tab has no project file, prompt for save as
        if (currentTab.ProjectFile is not { } projectFile)
        {
            await MenuSaveAs();
            return;
        }

        // Otherwise, save to the current project file
        var document = InferenceProjectDocument.FromLoadable(currentTab);

        try
        {
            await using var stream = projectFile.Info.OpenWrite();
            await WriteProjectDocumentAsync(stream, document);
        }
        catch (Exception e)
        {
            ShowSaveError(e);
            return;
        }

        notificationService.Show(
            "Saved",
            $"Saved project to {projectFile.Name}",
            NotificationType.Success
        );
    }

    /// <summary>
    /// Truncates <paramref name="stream"/> and writes <paramref name="document"/> to it as indented JSON.
    /// </summary>
    private static async Task WriteProjectDocumentAsync(
        System.IO.Stream stream,
        InferenceProjectDocument document
    )
    {
        stream.SetLength(0); // Overwrite fully
        await JsonSerializer.SerializeAsync(stream, document, ProjectJsonOptions);
    }

    /// <summary>
    /// Shows a persistent error notification for a failed project save.
    /// </summary>
    private void ShowSaveError(Exception e)
    {
        notificationService.ShowPersistent(
            "Could not save to file",
            $"[{e.GetType().Name}] {e.Message}",
            NotificationType.Error
        );
    }

    /// <summary>
    /// Ensures the "Projects" folder under the library directory exists and returns it
    /// as a storage folder for use as a picker start location (may be null).
    /// </summary>
    private async Task<IStorageFolder?> GetProjectsStartFolderAsync()
    {
        var projectDir = new DirectoryPath(settingsManager.LibraryDir, "Projects");
        projectDir.Create();
        return await App.StorageProvider.TryGetFolderFromPathAsync(projectDir);
    }

    /// <summary>
    /// Loads a project file into a new tab, selects it, and syncs tab state to the database.
    /// </summary>
    /// <exception cref="ApplicationException">The file deserializes to null or has no State.</exception>
    /// <exception cref="InvalidOperationException">The project type is not supported.</exception>
    /// <remarks>
    /// <c>document.VerifyVersion()</c> may also throw (callers catch NotSupportedException
    /// for unsupported versions).
    /// </remarks>
    private async Task AddTabFromFile(FilePath file)
    {
        InferenceProjectDocument? document;
        await using (var stream = file.Info.OpenRead())
        {
            document = await JsonSerializer.DeserializeAsync<InferenceProjectDocument>(stream);
        }

        if (document is null)
        {
            throw new ApplicationException(
                "MenuOpenProject: Deserialize project file returned null"
            );
        }

        if (document.State is null)
        {
            throw new ApplicationException("Project file does not have 'State' key");
        }

        document.VerifyVersion();

        if (document.ProjectType is not InferenceProjectType.TextToImage)
        {
            throw new InvalidOperationException(
                $"Unsupported project type: {document.ProjectType}"
            );
        }

        // Get view model
        var textToImage = vmFactory.Get<InferenceTextToImageViewModel>();

        try
        {
            // Load state
            textToImage.LoadStateFromJsonObject(document.State);
        }
        catch
        {
            // The view model never reaches Tabs, so nothing else would dispose it
            textToImage.Dispose();
            throw;
        }

        // Set the file backing the view model
        textToImage.ProjectFile = file;

        Tabs.Add(textToImage);

        SelectedTab = textToImage;

        await SyncTabStatesWithDatabase();
    }

    /// <summary>
    /// Menu "Open Project" command.
    /// Prompts for a project file and opens it in a new tab, reporting failures to the user.
    /// </summary>
    [RelayCommand(FlowExceptionsToTaskScheduler = true)]
    private async Task MenuOpenProject()
    {
        // Prompt for open file dialog
        var startDir = await GetProjectsStartFolderAsync();

        var results = await App.StorageProvider.OpenFilePickerAsync(
            new FilePickerOpenOptions
            {
                Title = "Open Project File",
                FileTypeFilter = new[] { ProjectFileType },
                SuggestedStartLocation = startDir,
            }
        );

        if (results.Count == 0)
        {
            Logger.Trace("MenuOpenProject: No files selected");
            return;
        }

        // Load from file; only items with a local path can be opened
        if (results[0].TryGetLocalPath() is not { } file)
        {
            notificationService.ShowPersistent(
                "Failed to load Project",
                $"[{results[0].Name}] Selected file does not have a local path",
                NotificationType.Error
            );
            return;
        }

        try
        {
            await AddTabFromFile(file);
        }
        catch (NotSupportedException e)
        {
            notificationService.ShowPersistent(
                "Unsupported Project Version",
                $"[{Path.GetFileName(file)}] {e.Message}",
                NotificationType.Error
            );
        }
        catch (Exception e)
        {
            notificationService.ShowPersistent(
                "Failed to load Project",
                $"[{Path.GetFileName(file)}] {e.Message}",
                NotificationType.Error
            );
        }
    }
}
|
|
|