using System.Text.Json;
using System.Threading.Tasks;
using Avalonia.Collections;
using Avalonia.Controls.Notifications;
using Avalonia.Platform.Storage;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Services;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;

namespace StabilityMatrix.Avalonia.ViewModels;

/// <summary>
/// View model for the Inference page. Manages the open project tabs, the
/// connection to the ComfyUI backend, and saving / opening project files.
/// </summary>
[View(typeof(InferencePage))]
public partial class InferenceViewModel : PageViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// File type shared by the "Save As" and "Open Project" file pickers.
    /// </summary>
    private static readonly FilePickerFileType ProjectFileType =
        new("StabilityMatrix Project")
        {
            Patterns = new[] { "*.smproj" },
            MimeTypes = new[] { "application/json" },
        };

    private readonly INotificationService notificationService;
    private readonly ISettingsManager settingsManager;
    private readonly ServiceManager<ViewModelBase> vmFactory;
    private readonly IApiFactory apiFactory;

    public override string Title => "Inference";

    public override IconSource IconSource =>
        new SymbolIconSource { Symbol = Symbol.AppGeneric, IsFilled = true };

    /// <summary>
    /// Badge showing the backend connection state; starts in the failed ("Not connected") state.
    /// </summary>
    public RefreshBadgeViewModel ConnectionBadge { get; } =
        new()
        {
            State = ProgressState.Failed,
            FailToolTipText = "Not connected",
            FailIcon = FluentAvalonia.UI.Controls.Symbol.Refresh,
            SuccessToolTipText = "Connected",
        };

    public IInferenceClientManager ClientManager { get; }

    /// <summary>
    /// The open project tabs.
    /// </summary>
    public AvaloniaList<LoadableViewModelBase> Tabs { get; } = new();

    [ObservableProperty]
    private LoadableViewModelBase? selectedTab;

    /// <summary>
    /// The package pair most recently reported by
    /// <c>EventManager.Instance.RunningPackageStatusChanged</c>; used as the connect target.
    /// </summary>
    [ObservableProperty]
    private PackagePair? runningPackage;

    public InferenceViewModel(
        ServiceManager<ViewModelBase> vmFactory,
        IApiFactory apiFactory,
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ISettingsManager settingsManager
    )
    {
        this.vmFactory = vmFactory;
        this.apiFactory = apiFactory;
        this.notificationService = notificationService;
        this.settingsManager = settingsManager;

        ClientManager = inferenceClientManager;

        // Keep RunningPackage updated with the current package pair.
        // NOTE(review): this subscription is never removed, so the global event keeps this
        // instance alive; fine if the page VM is a singleton — confirm.
        EventManager.Instance.RunningPackageStatusChanged += (_, args) =>
        {
            RunningPackage = args.CurrentPackagePair;
        };
    }

    /// <summary>
    /// Ensures at least one tab exists and that a tab is selected when the page loads.
    /// </summary>
    public override void OnLoaded()
    {
        if (Tabs.Count == 0)
        {
            AddTab();
        }

        // Select first tab if none is selected
        if (SelectedTab is null && Tabs.Count > 0)
        {
            SelectedTab = Tabs[0];
        }

        base.OnLoaded();
    }

    /// <summary>
    /// When the + button on the tab control is clicked, add a new tab.
    /// </summary>
    [RelayCommand]
    private void AddTab()
    {
        Tabs.Add(vmFactory.Get<InferenceTextToImageViewModel>());
    }

    /// <summary>
    /// When the close button on the tab is clicked, remove the tab.
    /// </summary>
    public void OnTabCloseRequested(TabViewTabCloseRequestedEventArgs e)
    {
        if (e.Item is LoadableViewModelBase vm)
        {
            Tabs.Remove(vm);
        }
    }

    /// <summary>
    /// Connect to the inference server.
    /// </summary>
    [RelayCommand]
    private async Task Connect()
    {
        if (ClientManager.IsConnected)
        {
            notificationService.Show("Already connected", "ComfyUI backend is already connected");
            return;
        }

        // TODO: make address configurable
        if (RunningPackage is null)
        {
            // Nothing to connect to; tell the user rather than silently ignoring the click.
            notificationService.Show(
                "Not connected",
                "No running ComfyUI package was found to connect to",
                NotificationType.Warning
            );
            return;
        }

        await notificationService.TryAsync(
            ClientManager.ConnectAsync(RunningPackage),
            "Could not connect to backend"
        );
    }

    /// <summary>
    /// Disconnect from the inference server.
    /// </summary>
    [RelayCommand]
    private async Task Disconnect()
    {
        if (!ClientManager.IsConnected)
        {
            notificationService.Show("Not connected", "ComfyUI backend is not connected");
            return;
        }

        await notificationService.TryAsync(
            ClientManager.CloseAsync(),
            "Could not disconnect from ComfyUI backend"
        );
    }

    /// <summary>
    /// Ensures the "Projects" folder under the library directory exists and returns it
    /// as a picker start location (null if the provider cannot resolve it).
    /// </summary>
    private async Task<IStorageFolder?> GetProjectsStartFolderAsync(IStorageProvider provider)
    {
        var projectDir = new DirectoryPath(settingsManager.LibraryDir, "Projects");
        projectDir.Create();

        return await provider.TryGetFolderFromPathAsync(projectDir);
    }

    /// <summary>
    /// Menu "Save As" command. Serializes the selected tab to a user-chosen .smproj file.
    /// </summary>
    [RelayCommand]
    private async Task MenuSaveAs()
    {
        var currentTab = SelectedTab;
        if (currentTab is null)
        {
            Logger.Trace("MenuSaveAs: currentTab is null");
            return;
        }

        var document = InferenceProjectDocument.FromLoadable(currentTab);

        // Prompt for save file dialog
        var provider = App.StorageProvider;
        var startDir = await GetProjectsStartFolderAsync(provider);

        var result = await provider.SaveFilePickerAsync(
            new FilePickerSaveOptions
            {
                Title = "Save As",
                SuggestedFileName = "Untitled",
                FileTypeChoices = new[] { ProjectFileType },
                SuggestedStartLocation = startDir,
                DefaultExtension = ".smproj",
                ShowOverwritePrompt = true,
            }
        );

        if (result is null)
        {
            Logger.Trace("MenuSaveAs: user cancelled");
            return;
        }

        // Save to file
        await using (var stream = await result.OpenWriteAsync())
        {
            // Defensive: if the provider opened an existing file without truncating it,
            // writing shorter JSON would leave stale trailing bytes and corrupt the project.
            if (stream.CanSeek)
            {
                stream.SetLength(0);
            }

            await JsonSerializer.SerializeAsync(
                stream,
                document,
                new JsonSerializerOptions { WriteIndented = true, }
            );
        }

        notificationService.Show(
            "Saved",
            $"Saved project to {result.Name}",
            NotificationType.Success
        );
    }

    /// <summary>
    /// Menu "Open Project" command. Loads a user-chosen .smproj file into a new tab.
    /// </summary>
    [RelayCommand]
    private async Task MenuOpenProject()
    {
        // Prompt for open file dialog
        var provider = App.StorageProvider;
        var startDir = await GetProjectsStartFolderAsync(provider);

        var results = await provider.OpenFilePickerAsync(
            new FilePickerOpenOptions
            {
                Title = "Open Project File",
                FileTypeFilter = new[] { ProjectFileType },
                SuggestedStartLocation = startDir,
            }
        );

        if (results.Count == 0)
        {
            Logger.Trace("MenuOpenProject: No files selected");
            return;
        }

        // Load from file
        var file = results[0];

        InferenceProjectDocument? document;
        try
        {
            await using var stream = await file.OpenReadAsync();
            document = await JsonSerializer.DeserializeAsync<InferenceProjectDocument>(stream);
        }
        catch (JsonException e)
        {
            // A malformed or foreign file must not escape the command as an unhandled exception.
            Logger.Warn(e, "MenuOpenProject: Failed to deserialize project file {FileName}", file.Name);
            notificationService.Show(
                "Could not open project",
                $"{file.Name} is not a valid project file",
                NotificationType.Error
            );
            return;
        }

        if (document is null)
        {
            Logger.Warn("MenuOpenProject: Deserialize project file returned null");
            return;
        }

        LoadableViewModelBase? vm = null;

        if (document.ProjectType is InferenceProjectType.TextToImage && document.State is not null)
        {
            var textToImage = vmFactory.Get<InferenceTextToImageViewModel>();
            textToImage.LoadStateFromJsonObject(document.State);
            vm = textToImage;
        }

        if (vm is null)
        {
            Logger.Warn("MenuOpenProject: Unknown project type");
            return;
        }

        Tabs.Add(vm);
    }
}