diff --git a/StabilityMatrix/App.xaml.cs b/StabilityMatrix/App.xaml.cs index 70a9c2d6..b31e852f 100644 --- a/StabilityMatrix/App.xaml.cs +++ b/StabilityMatrix/App.xaml.cs @@ -39,6 +39,7 @@ namespace StabilityMatrix var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); diff --git a/StabilityMatrix/Services/PageContentDialogService.cs b/StabilityMatrix/Services/PageContentDialogService.cs new file mode 100644 index 00000000..24c79857 --- /dev/null +++ b/StabilityMatrix/Services/PageContentDialogService.cs @@ -0,0 +1,8 @@ +using Wpf.Ui.Services; + +namespace StabilityMatrix.Services; + +public class PageContentDialogService : ContentDialogService +{ + +} diff --git a/StabilityMatrix/TextToImagePage.xaml b/StabilityMatrix/TextToImagePage.xaml index a0414784..e7e5da2a 100644 --- a/StabilityMatrix/TextToImagePage.xaml +++ b/StabilityMatrix/TextToImagePage.xaml @@ -1,6 +1,7 @@  - - - - - + + + + + + - - - - - - + + + + + + - - - - - - - - + + + + + + - - - - - - - - - + - - + + + + + + + + + - + Orientation="Vertical" + VerticalAlignment="Top"> + + + + + - - - + + - - + + + + diff --git a/StabilityMatrix/TextToImagePage.xaml.cs b/StabilityMatrix/TextToImagePage.xaml.cs index f3f1513d..56d4e944 100644 --- a/StabilityMatrix/TextToImagePage.xaml.cs +++ b/StabilityMatrix/TextToImagePage.xaml.cs @@ -1,4 +1,6 @@ -using System.Windows.Controls; +using System.Windows; +using System.Windows.Controls; +using StabilityMatrix.Services; using StabilityMatrix.ViewModels; using Wpf.Ui.Controls.AutoSuggestBoxControl; @@ -8,10 +10,11 @@ public sealed partial class TextToImagePage : Page { private TextToImageViewModel ViewModel => (TextToImageViewModel) DataContext; - public TextToImagePage(TextToImageViewModel viewModel) + public TextToImagePage(TextToImageViewModel viewModel, PageContentDialogService pageContentDialogService) { InitializeComponent(); DataContext = viewModel; + pageContentDialogService.SetContentPresenter(PageContentDialog); } private void PositivePromptBox_OnQuerySubmitted(AutoSuggestBox sender, AutoSuggestBoxQuerySubmittedEventArgs args) @@ -49,4 +52,9 @@ public sealed partial class TextToImagePage : Page ViewModel.NegativePromptText = fullText; } } + + private async void TextToImagePage_OnLoaded(object sender, RoutedEventArgs e) + { + await ViewModel.OnLoaded(); + } } diff --git a/StabilityMatrix/ViewModels/TextToImageViewModel.cs b/StabilityMatrix/ViewModels/TextToImageViewModel.cs index 9c81bbfd..1c03ab0f 100644 --- a/StabilityMatrix/ViewModels/TextToImageViewModel.cs +++ b/StabilityMatrix/ViewModels/TextToImageViewModel.cs @@ -9,6 +9,8 @@ using Microsoft.Extensions.Logging; using StabilityMatrix.Api; using StabilityMatrix.Helper; using StabilityMatrix.Models.Api; +using StabilityMatrix.Services; +using Wpf.Ui.Contracts; namespace StabilityMatrix.ViewModels; @@ -16,11 +18,16 @@ public partial class TextToImageViewModel : ObservableObject { private readonly ILogger logger; private readonly IA3WebApi a3WebApi; + private readonly IDialogErrorHandler dialogErrorHandler; + private readonly PageContentDialogService pageContentDialogService; private AsyncDispatcherTimer? progressQueryTimer; [ObservableProperty] private bool isGenerating; - + + [ObservableProperty] + private bool connectionFailed; + [ObservableProperty] [NotifyPropertyChangedFor(nameof(ProgressRingVisibility))] [NotifyPropertyChangedFor(nameof(ImagePreviewVisibility))] @@ -48,14 +55,66 @@ public partial class TextToImageViewModel : ObservableObject public Visibility ProgressBarVisibility => ProgressValue > 0 ? Visibility.Visible : Visibility.Collapsed; - public TextToImageViewModel(IA3WebApi a3WebApi, ILogger logger) + public TextToImageViewModel(IA3WebApi a3WebApi, ILogger logger, IDialogErrorHandler dialogErrorHandler, PageContentDialogService pageContentDialogService) { this.logger = logger; this.a3WebApi = a3WebApi; + this.dialogErrorHandler = dialogErrorHandler; + this.pageContentDialogService = pageContentDialogService; positivePromptText = "Positive"; negativePromptText = "Negative"; generationSteps = 10; } + + public async Task OnLoaded() + { + if (ConnectionFailed) + { + await PromptRetryConnection(); + } + else + { + await CheckConnection(); + } + } + + // Checks connection, if unsuccessful, shows a content dialog to retry + private async Task CheckConnection() + { + try + { + await a3WebApi.GetPing(); + ConnectionFailed = false; + } + catch (Exception e) + { + // On error, show a content dialog to retry + ConnectionFailed = true; + logger.LogWarning("Ping response failed: {EMessage}", e.Message); + var dialog = pageContentDialogService.CreateDialog(); + dialog.Title = "Connection failed"; + dialog.Content = "Please check the server is running with the --api launch option enabled."; + dialog.CloseButtonText = "Retry"; + dialog.IsPrimaryButtonEnabled = false; + dialog.IsSecondaryButtonEnabled = false; + await dialog.ShowAsync(); + // Retry + await CheckConnection(); + } + } + + private async Task PromptRetryConnection() + { + var dialog = pageContentDialogService.CreateDialog(); + dialog.Title = "Connection failed"; + dialog.Content = "Please check the server is running with the --api launch option enabled."; + dialog.CloseButtonText = "Retry"; + dialog.IsPrimaryButtonEnabled = false; + dialog.IsSecondaryButtonEnabled = false; + await dialog.ShowAsync(); + // Retry + await CheckConnection(); + } private void StartProgressTracking(TimeSpan? interval = null) { @@ -78,7 +137,15 @@ public partial class TextToImageViewModel : ObservableObject private async Task OnProgressTrackingTick() { var request = new ProgressRequest(); - var response = await a3WebApi.GetProgress(request); + var task = a3WebApi.GetProgress(request); + var responseResult = await dialogErrorHandler.TryAsync(task, "Failed to get progress"); + if (!responseResult.IsSuccessful || responseResult.Result == null) + { + StopProgressTracking(); + return; + } + + var response = responseResult.Result; var progress = response.Progress; logger.LogInformation("Image Progress: {ResponseProgress}, ETA: {ResponseEtaRelative} s", response.Progress, response.EtaRelative); if (Math.Abs(progress - 1.0) < 0.01) @@ -133,11 +200,13 @@ public partial class TextToImageViewModel : ObservableObject // Progress track while waiting for response StartProgressTracking(); - var response = await task; + var response = await dialogErrorHandler.TryAsync(task, "Failed to get a response from the server"); StopProgressTracking(); + if (!response.IsSuccessful || response.Result == null) return; + // Decode base64 image - var result = response.Images[0]; + var result = response.Result.Images[0]; var bitmap = Base64ToBitmap(result); ImagePreview = bitmap;