using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.Controls.Notifications;
using Avalonia.Platform.Storage;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using DynamicData.Binding;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views.Settings;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels.Settings;
[View(typeof(InferenceSettingsPage))]
[Singleton, ManagedService]
public partial class InferenceSettingsViewModel : PageViewModelBase
{
    private readonly INotificationService notificationService;
    private readonly ISettingsManager settingsManager;
    private readonly ICompletionProvider completionProvider;

    /// <inheritdoc />
    public override string Title => "Inference";

    /// <inheritdoc />
    public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.Settings, IsFilled = true };

    // Relayed to settings.IsPromptCompletionEnabled in the constructor.
    [ObservableProperty]
    private bool isPromptCompletionEnabled = true;

    // File names of the *.csv files found in the tags directory (see UpdateAvailableTagCompletionCsvs).
    [ObservableProperty]
    private IReadOnlyList<string> availableTagCompletionCsvs = Array.Empty<string>();

    // Relayed to settings.TagCompletionCsv in the constructor.
    [ObservableProperty]
    private string? selectedTagCompletionCsv;

    // Relayed to settings.IsCompletionRemoveUnderscoresEnabled in the constructor.
    [ObservableProperty]
    private bool isCompletionRemoveUnderscoresEnabled = true;

    // Relayed to settings.InferenceOutputImageFileNameFormat; validated by ValidateOutputImageFileNameFormat.
    [ObservableProperty]
    [CustomValidation(typeof(InferenceSettingsViewModel), nameof(ValidateOutputImageFileNameFormat))]
    private string? outputImageFileNameFormat;

    // Preview file name (with ".png" appended) rendered from the current format, or from the default template.
    [ObservableProperty]
    private string? outputImageFileNameFormatSample;

    /// <summary>
    /// Placeholder variables usable in the output file name format, each paired with an
    /// example value produced by the sample <see cref="FileNameFormatProvider"/>.
    /// </summary>
    public IEnumerable<FileNameFormatVar> OutputImageFileNameFormatVars =>
        FileNameFormatProvider
            .GetSample()
            .Substitutions
            .Select(kv => new FileNameFormatVar { Variable = $"{{{kv.Key}}}", Example = kv.Value.Invoke() });

    // Relayed to settings.IsImageViewerPixelGridEnabled in the constructor.
    [ObservableProperty]
    private bool isImageViewerPixelGridEnabled = true;

    // NOTE(review): prerequisiteHelper, pyRunner, dialogFactory, trackedDownloadService,
    // modelIndexService, settingsNavigationService and accountsService are not used by this class;
    // presumably kept so the DI-resolved constructor signature stays stable — confirm before removing.
    public InferenceSettingsViewModel(
        INotificationService notificationService,
        IPrerequisiteHelper prerequisiteHelper,
        IPyRunner pyRunner,
        ServiceManager dialogFactory,
        ICompletionProvider completionProvider,
        ITrackedDownloadService trackedDownloadService,
        IModelIndexService modelIndexService,
        INavigationService settingsNavigationService,
        IAccountsService accountsService,
        ISettingsManager settingsManager
    )
    {
        this.settingsManager = settingsManager;
        this.notificationService = notificationService;
        this.completionProvider = completionProvider;

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.SelectedTagCompletionCsv,
            settings => settings.TagCompletionCsv
        );

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.IsPromptCompletionEnabled,
            settings => settings.IsPromptCompletionEnabled,
            true
        );

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.IsCompletionRemoveUnderscoresEnabled,
            settings => settings.IsCompletionRemoveUnderscoresEnabled,
            true
        );

        // Recompute the preview whenever the format changes, throttled to avoid work on every keystroke.
        // NOTE(review): Throttle without an explicit scheduler emits on Rx's default (thread-pool)
        // scheduler, so OutputImageFileNameFormatSample may be set off the UI thread — confirm the
        // view's bindings tolerate that, or add an ObserveOn to the UI thread.
        this.WhenPropertyChanged(vm => vm.OutputImageFileNameFormat)
            .Throttle(TimeSpan.FromMilliseconds(50))
            .Subscribe(formatProperty =>
            {
                var provider = FileNameFormatProvider.GetSample();
                var template = formatProperty.Value ?? string.Empty;

                if (!string.IsNullOrEmpty(template) && provider.Validate(template) == ValidationResult.Success)
                {
                    var format = FileNameFormat.Parse(template, provider);
                    OutputImageFileNameFormatSample = format.GetFileName() + ".png";
                }
                else
                {
                    // Empty or invalid template: preview the default format instead
                    var defaultFormat = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, provider);
                    OutputImageFileNameFormatSample = defaultFormat.GetFileName() + ".png";
                }
            });

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.OutputImageFileNameFormat,
            settings => settings.InferenceOutputImageFileNameFormat,
            true
        );

        settingsManager.RelayPropertyFor(
            this,
            vm => vm.IsImageViewerPixelGridEnabled,
            settings => settings.IsImageViewerPixelGridEnabled,
            true
        );

        ImportTagCsvCommand.WithNotificationErrorHandler(notificationService, LogLevel.Warn);
    }

    /// <summary>
    /// Validator for <see cref="OutputImageFileNameFormat"/>.
    /// A <c>null</c> format is validated as an empty string using the sample provider.
    /// </summary>
    /// <param name="format">The format template entered by the user.</param>
    /// <param name="context">Validation context (unused).</param>
    /// <returns>The result of <c>FileNameFormatProvider.Validate</c> for the template.</returns>
    public static ValidationResult ValidateOutputImageFileNameFormat(string? format, ValidationContext context)
    {
        return FileNameFormatProvider.GetSample().Validate(format ?? string.Empty);
    }

    /// <inheritdoc />
    public override void OnLoaded()
    {
        base.OnLoaded();

        UpdateAvailableTagCompletionCsvs();
    }

    #region Commands

    /// <summary>
    /// Lets the user pick a CSV file, copies it into the tags directory, refreshes the list of
    /// available CSVs and starts a background load of the copied file into the completion provider.
    /// </summary>
    [RelayCommand(FlowExceptionsToTaskScheduler = true)]
    private async Task ImportTagCsv()
    {
        var storage = App.StorageProvider;
        var files = await storage.OpenFilePickerAsync(
            new FilePickerOpenOptions
            {
                FileTypeFilter = new List<FilePickerFileType> { new("CSV") { Patterns = ["*.csv"] } }
            }
        );

        // User cancelled the picker
        if (files.Count == 0)
            return;

        // TryGetLocalPath returns null for items that have no local file-system path;
        // report that instead of crashing on a null path.
        if (files[0].TryGetLocalPath() is not { } localPath)
        {
            notificationService.Show(
                "Import failed",
                $"The selected file \"{files[0].Name}\" is not available as a local file.",
                NotificationType.Error
            );
            return;
        }

        var sourceFile = new FilePath(localPath);

        var tagsDir = settingsManager.TagsDirectory;
        tagsDir.Create();

        // Copy to tags directory
        var targetFile = tagsDir.JoinFile(sourceFile.Name);
        await sourceFile.CopyToAsync(targetFile);

        // Update index
        UpdateAvailableTagCompletionCsvs();

        // Trigger load
        completionProvider.BackgroundLoadFromFile(targetFile, true);

        notificationService.Show(
            $"Imported {sourceFile.Name}",
            $"The {sourceFile.Name} file has been imported.",
            NotificationType.Success
        );
    }

    #endregion

    /// <summary>
    /// Rebuilds <see cref="AvailableTagCompletionCsvs"/> from the *.csv files in the tags directory
    /// and re-selects the CSV stored in settings if it is still present.
    /// Does nothing when the library directory is not set or the tags directory does not exist.
    /// </summary>
    private void UpdateAvailableTagCompletionCsvs()
    {
        if (!settingsManager.IsLibraryDirSet)
            return;

        if (settingsManager.TagsDirectory is not { Exists: true } tagsDir)
            return;

        // Sort so the list order does not depend on file-system enumeration order
        AvailableTagCompletionCsvs = tagsDir
            .Info.EnumerateFiles("*.csv")
            .Select(f => f.Name)
            .OrderBy(name => name, StringComparer.OrdinalIgnoreCase)
            .ToImmutableArray();

        // Set selected to current if exists
        var settingsCsv = settingsManager.Settings.TagCompletionCsv;
        if (settingsCsv is not null && AvailableTagCompletionCsvs.Contains(settingsCsv))
        {
            SelectedTagCompletionCsv = settingsCsv;
        }
    }
}