|
|
|
@ -9,6 +9,8 @@ using Microsoft.Extensions.Logging;
|
|
|
|
|
using StabilityMatrix.Api; |
|
|
|
|
using StabilityMatrix.Helper; |
|
|
|
|
using StabilityMatrix.Models.Api; |
|
|
|
|
using StabilityMatrix.Services; |
|
|
|
|
using Wpf.Ui.Contracts; |
|
|
|
|
|
|
|
|
|
namespace StabilityMatrix.ViewModels; |
|
|
|
|
|
|
|
|
@ -16,11 +18,16 @@ public partial class TextToImageViewModel : ObservableObject
|
|
|
|
|
{ |
|
|
|
|
private readonly ILogger<TextToImageViewModel> logger; |
|
|
|
|
private readonly IA3WebApi a3WebApi; |
|
|
|
|
private readonly IDialogErrorHandler dialogErrorHandler; |
|
|
|
|
private readonly PageContentDialogService pageContentDialogService; |
|
|
|
|
private AsyncDispatcherTimer? progressQueryTimer; |
|
|
|
|
|
|
|
|
|
[ObservableProperty] |
|
|
|
|
private bool isGenerating; |
|
|
|
|
|
|
|
|
|
[ObservableProperty] |
|
|
|
|
private bool connectionFailed; |
|
|
|
|
|
|
|
|
|
[ObservableProperty] |
|
|
|
|
[NotifyPropertyChangedFor(nameof(ProgressRingVisibility))] |
|
|
|
|
[NotifyPropertyChangedFor(nameof(ImagePreviewVisibility))] |
|
|
|
@ -48,15 +55,67 @@ public partial class TextToImageViewModel : ObservableObject
|
|
|
|
|
|
|
|
|
|
public Visibility ProgressBarVisibility => ProgressValue > 0 ? Visibility.Visible : Visibility.Collapsed; |
|
|
|
|
|
|
|
|
|
public TextToImageViewModel(IA3WebApi a3WebApi, ILogger<TextToImageViewModel> logger) |
|
|
|
|
public TextToImageViewModel(IA3WebApi a3WebApi, ILogger<TextToImageViewModel> logger, IDialogErrorHandler dialogErrorHandler, PageContentDialogService pageContentDialogService) |
|
|
|
|
{ |
|
|
|
|
this.logger = logger; |
|
|
|
|
this.a3WebApi = a3WebApi; |
|
|
|
|
this.dialogErrorHandler = dialogErrorHandler; |
|
|
|
|
this.pageContentDialogService = pageContentDialogService; |
|
|
|
|
positivePromptText = "Positive"; |
|
|
|
|
negativePromptText = "Negative"; |
|
|
|
|
generationSteps = 10; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public async Task OnLoaded() |
|
|
|
|
{ |
|
|
|
|
if (ConnectionFailed) |
|
|
|
|
{ |
|
|
|
|
await PromptRetryConnection(); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
await CheckConnection(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Checks connection, if unsuccessful, shows a content dialog to retry |
|
|
|
|
private async Task CheckConnection() |
|
|
|
|
{ |
|
|
|
|
try |
|
|
|
|
{ |
|
|
|
|
await a3WebApi.GetPing(); |
|
|
|
|
ConnectionFailed = false; |
|
|
|
|
} |
|
|
|
|
catch (Exception e) |
|
|
|
|
{ |
|
|
|
|
// On error, show a content dialog to retry |
|
|
|
|
ConnectionFailed = true; |
|
|
|
|
logger.LogWarning("Ping response failed: {EMessage}", e.Message); |
|
|
|
|
var dialog = pageContentDialogService.CreateDialog(); |
|
|
|
|
dialog.Title = "Connection failed"; |
|
|
|
|
dialog.Content = "Please check the server is running with the --api launch option enabled."; |
|
|
|
|
dialog.CloseButtonText = "Retry"; |
|
|
|
|
dialog.IsPrimaryButtonEnabled = false; |
|
|
|
|
dialog.IsSecondaryButtonEnabled = false; |
|
|
|
|
await dialog.ShowAsync(); |
|
|
|
|
// Retry |
|
|
|
|
await CheckConnection(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private async Task PromptRetryConnection() |
|
|
|
|
{ |
|
|
|
|
var dialog = pageContentDialogService.CreateDialog(); |
|
|
|
|
dialog.Title = "Connection failed"; |
|
|
|
|
dialog.Content = "Please check the server is running with the --api launch option enabled."; |
|
|
|
|
dialog.CloseButtonText = "Retry"; |
|
|
|
|
dialog.IsPrimaryButtonEnabled = false; |
|
|
|
|
dialog.IsSecondaryButtonEnabled = false; |
|
|
|
|
await dialog.ShowAsync(); |
|
|
|
|
// Retry |
|
|
|
|
await CheckConnection(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private void StartProgressTracking(TimeSpan? interval = null) |
|
|
|
|
{ |
|
|
|
|
progressQueryTimer = new AsyncDispatcherTimer |
|
|
|
@ -78,7 +137,15 @@ public partial class TextToImageViewModel : ObservableObject
|
|
|
|
|
private async Task OnProgressTrackingTick() |
|
|
|
|
{ |
|
|
|
|
var request = new ProgressRequest(); |
|
|
|
|
var response = await a3WebApi.GetProgress(request); |
|
|
|
|
var task = a3WebApi.GetProgress(request); |
|
|
|
|
var responseResult = await dialogErrorHandler.TryAsync(task, "Failed to get progress"); |
|
|
|
|
if (!responseResult.IsSuccessful || responseResult.Result == null) |
|
|
|
|
{ |
|
|
|
|
StopProgressTracking(); |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var response = responseResult.Result; |
|
|
|
|
var progress = response.Progress; |
|
|
|
|
logger.LogInformation("Image Progress: {ResponseProgress}, ETA: {ResponseEtaRelative} s", response.Progress, response.EtaRelative); |
|
|
|
|
if (Math.Abs(progress - 1.0) < 0.01) |
|
|
|
@ -133,11 +200,13 @@ public partial class TextToImageViewModel : ObservableObject
|
|
|
|
|
|
|
|
|
|
// Progress track while waiting for response |
|
|
|
|
StartProgressTracking(); |
|
|
|
|
var response = await task; |
|
|
|
|
var response = await dialogErrorHandler.TryAsync(task, "Failed to get a response from the server"); |
|
|
|
|
StopProgressTracking(); |
|
|
|
|
|
|
|
|
|
if (!response.IsSuccessful || response.Result == null) return; |
|
|
|
|
|
|
|
|
|
// Decode base64 image |
|
|
|
|
var result = response.Images[0]; |
|
|
|
|
var result = response.Result.Images[0]; |
|
|
|
|
var bitmap = Base64ToBitmap(result); |
|
|
|
|
|
|
|
|
|
ImagePreview = bitmap; |
|
|
|
|