diff --git a/CHANGELOG.md b/CHANGELOG.md index f69ee68d..bce62ccd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,42 @@ All notable changes to Stability Matrix will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2.0.0.html). +## v2.6.0 +### Added +- Added **Output Sharing** option for all packages in the three-dots menu on the Packages page + - This will link the package's output folders to the relevant subfolders in the "Outputs" directory + - When a package only has a generic "outputs" folder, all generated images from that package will be linked to the "Outputs\Text2Img" folder when this option is enabled +- Added **Outputs page** for viewing generated images from any package, or the shared output folder +- Added [Stable Diffusion WebUI/UX](https://github.com/anapnoe/stable-diffusion-webui-ux) package +- Added [Stable Diffusion WebUI-DirectML](https://github.com/lshqqytiger/stable-diffusion-webui-directml) package +- Added [kohya_ss](https://github.com/bmaltais/kohya_ss) package +- Added [Fooocus-ControlNet-SDXL](https://github.com/fenneishi/Fooocus-ControlNet-SDXL) package +- Added GPU compatibility badges to the installers +- Added filtering of "incompatible" packages (ones that do not support your GPU) to all installers + - This can be overridden by checking the new "Show All Packages" checkbox +- Added more launch options for Fooocus, such as the `--preset` option +- Added Ctrl+ScrollWheel to change image size in the inference output gallery and new Outputs page +- Added "No Images Found" placeholder for non-connected models on the Checkpoints tab +### Changed +- If ComfyUI for Inference is chosen during the One-Click Installer, the Inference page will be opened after installation instead of the Launch page +- Changed all package installs & updates to use git commands instead of downloading zip files +- The One-Click Installer now uses the new progress dialog with console +- NVIDIA GPU users will be updated to use CUDA 12.1 for ComfyUI & Fooocus packages for a slight performance improvement + - Update will occur the next time the package is updated, or on a fresh install + - Note: CUDA 12.1 is only available on Maxwell (GTX 900 series) and newer GPUs +- Improved Model Browser download stability with automatic retries for download errors +- Optimized page navigation and syntax formatting configurations to improve startup time +### Fixed +- Fixed crash when clicking Inference gallery image after the image is deleted externally in file explorer +- Fixed Inference popup Install button not working on One-Click Installer +- Fixed Inference Prompt Completion window sometimes not showing while typing +- Fixed "Show Model Images" toggle on Checkpoints page sometimes displaying cut-off model images +- Fixed missing httpx package during Automatic1111 install +- Fixed some instances of localized text being cut off from controls being too small + ## v2.5.7 ### Fixed -- Fixed error `got an unexpected keyword argument 'socket_options'` on fresh installs of Automatic1111 Stable Diffusion WebUI -due to missing httpx dependency specification from gradio +- Fixed error `got an unexpected keyword argument 'socket_options'` on fresh installs of Automatic1111 Stable Diffusion WebUI due to missing httpx dependency specification from gradio ## v2.5.6 ### Added diff --git a/Jenkinsfile b/Jenkinsfile index 7b493b1a..51dddff3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -35,11 +35,6 @@ node("Diligence") { stage('Publish Linux') { sh "/home/jenkins/.dotnet/tools/pupnet --runtime linux-x64 --kind appimage --app-version ${version} --clean -y" } - - stage ('Archive Artifacts') { - archiveArtifacts artifacts: 'out/*.exe', followSymlinks: false - archiveArtifacts artifacts: 'Release/linux-x64/*.AppImage', followSymlinks: false - } } } finally { stage('Cleanup') { diff --git a/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj b/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj index cfbcd729..75091167 100644 --- a/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj +++ b/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj @@ -19,12 +19,12 @@ - - + + - - + + diff --git a/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs b/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs index fef4f8db..66a66152 100644 --- a/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs +++ b/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs @@ -1,12 +1,10 @@ using System; using System.Threading; -using AsyncAwaitBestPractices; using Avalonia; using Avalonia.Animation; using Avalonia.Animation.Easings; using Avalonia.Media; using Avalonia.Styling; -using FluentAvalonia.UI.Media.Animation; namespace StabilityMatrix.Avalonia.Animations; @@ -23,7 +21,7 @@ public class BetterEntranceNavigationTransition : BaseTransitionInfo /// Gets or sets the Vertical Offset used when animating /// public double FromVerticalOffset { get; set; } = 100; - + public override async void RunAnimation(Animatable ctrl, CancellationToken cancellationToken) { var animation = new Animation @@ -36,7 +34,7 @@ public class BetterEntranceNavigationTransition : BaseTransitionInfo Setters = { new Setter(Visual.OpacityProperty, 0.0), - new Setter(TranslateTransform.XProperty,FromHorizontalOffset), + new Setter(TranslateTransform.XProperty, FromHorizontalOffset), new Setter(TranslateTransform.YProperty, FromVerticalOffset) }, Cue = new Cue(0d) @@ -46,7 +44,7 @@ public class BetterEntranceNavigationTransition : BaseTransitionInfo Setters = { new Setter(Visual.OpacityProperty, 1d), - new Setter(TranslateTransform.XProperty,0.0), + new Setter(TranslateTransform.XProperty, 0.0), new Setter(TranslateTransform.YProperty, 0.0) }, Cue = new Cue(1d) diff --git a/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs b/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs index 507d9872..7014e63f 100644 --- a/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs +++ b/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs @@ -2,7 +2,6 @@ using Avalonia; using Avalonia.Controls; using Avalonia.Rendering.Composition; -using Avalonia.Rendering.Composition.Animations; namespace StabilityMatrix.Avalonia.Animations; diff --git a/StabilityMatrix.Avalonia/App.axaml b/StabilityMatrix.Avalonia/App.axaml index 0258b36b..f33a4d5c 100644 --- a/StabilityMatrix.Avalonia/App.axaml +++ b/StabilityMatrix.Avalonia/App.axaml @@ -20,6 +20,8 @@ + + @@ -56,10 +58,10 @@ - + --> - - + + + + + + + + + + + + + + + + + + - + - + FontSize="12" + Text="{Binding FileNameWithoutExtension}" + TextAlignment="Center" + TextTrimming="CharacterEllipsis" /> @@ -206,10 +229,10 @@ - + - - + + diff --git a/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs index b36e1fe1..08a9cca5 100644 --- a/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs @@ -1,9 +1,23 @@ -using Avalonia.Input; +using Avalonia.Controls; +using Avalonia.Controls.Primitives; +using Avalonia.Input; +using StabilityMatrix.Avalonia.ViewModels.Inference; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models.Settings; namespace StabilityMatrix.Avalonia.Controls; +[Transient] public class ImageFolderCard : DropTargetTemplatedControlBase { + private ItemsRepeater? imageRepeater; + + protected override void OnApplyTemplate(TemplateAppliedEventArgs e) + { + imageRepeater = e.NameScope.Find("ImageRepeater"); + base.OnApplyTemplate(e); + } + /// protected override void DropHandler(object? sender, DragEventArgs e) { @@ -17,4 +31,30 @@ public class ImageFolderCard : DropTargetTemplatedControlBase base.DragOverHandler(sender, e); e.Handled = true; } + + protected override void OnPointerWheelChanged(PointerWheelEventArgs e) + { + if (e.KeyModifiers != KeyModifiers.Control) + return; + if (DataContext is not ImageFolderCardViewModel vm) + return; + + if (e.Delta.Y > 0) + { + if (vm.ImageSize.Height >= 500) + return; + vm.ImageSize += new Size(15, 19); + } + else + { + if (vm.ImageSize.Height <= 200) + return; + vm.ImageSize -= new Size(15, 19); + } + + imageRepeater?.InvalidateArrange(); + imageRepeater?.InvalidateMeasure(); + + e.Handled = true; + } } diff --git a/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml b/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml index 3fe4071e..6fb6e844 100644 --- a/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml +++ b/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml @@ -13,11 +13,6 @@ - - - + + + - - + + + + + + + + + + + + + + + - + IsVisible="{Binding IsRefinerSelectionEnabled}" + Text="{x:Static lang:Resources.Label_Refiner}" + TextAlignment="Left" /> + - - + + - + IsVisible="{Binding IsVaeSelectionEnabled}" + Text="{x:Static lang:Resources.Label_VAE}" + TextAlignment="Left" /> + - + diff --git a/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs index e997a874..622f1cbe 100644 --- a/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs @@ -1,9 +1,7 @@ -using Avalonia; -using Avalonia.Controls; -using Avalonia.Controls.Primitives; +using Avalonia.Controls.Primitives; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; -public class ModelCard : TemplatedControl -{ -} \ No newline at end of file +[Transient] +public class ModelCard : TemplatedControl { } diff --git a/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs b/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs index 9342ec12..ef1f1114 100644 --- a/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs @@ -2,7 +2,6 @@ using System.Windows.Input; using Avalonia; using Avalonia.Controls.Primitives; -using AvaloniaEdit.Utils; using CommunityToolkit.Mvvm.Input; namespace StabilityMatrix.Avalonia.Controls; diff --git a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs index 47a02af1..f5a79dcb 100644 --- a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs @@ -6,9 +6,12 @@ using AvaloniaEdit; using AvaloniaEdit.Editing; using AvaloniaEdit.Utils; using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; +[Transient] public class PromptCard : TemplatedControl { /// @@ -31,7 +34,7 @@ public class PromptCard : TemplatedControl { if (editor is not null) { - TextEditorConfigs.ConfigForPrompt(editor); + TextEditorConfigs.Configure(editor, TextEditorPreset.Prompt); editor.TextArea.Margin = new Thickness(0, 0, 4, 0); if (editor.TextArea.ActiveInputHandler is TextAreaInputHandler inputHandler) diff --git a/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs b/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs index f5620f12..1943c3ae 100644 --- a/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs @@ -1,7 +1,9 @@ using Avalonia.Markup.Xaml; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; +[Transient] public partial class RefreshBadge : UserControlBase { public RefreshBadge() diff --git a/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs index 7e732e0f..ace972b1 100644 --- a/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs @@ -1,7 +1,7 @@ using Avalonia.Controls.Primitives; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; -public class SamplerCard : TemplatedControl -{ -} +[Transient] +public class SamplerCard : TemplatedControl { } diff --git a/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollContentPresenter.cs b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollContentPresenter.cs new file mode 100644 index 00000000..a5f30ff1 --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollContentPresenter.cs @@ -0,0 +1,14 @@ +using Avalonia.Controls.Presenters; +using Avalonia.Input; + +namespace StabilityMatrix.Avalonia.Controls.Scroll; + +public class BetterScrollContentPresenter : ScrollContentPresenter +{ + protected override void OnPointerWheelChanged(PointerWheelEventArgs e) + { + if (e.KeyModifiers == KeyModifiers.Control) + return; + base.OnPointerWheelChanged(e); + } +} diff --git a/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.axaml b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.axaml new file mode 100644 index 00000000..48f59f05 --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.axaml @@ -0,0 +1,69 @@ + + + + + Item 1 + Item 2 + Item 3 + Item 4 + Item 5 + Item 6 + Item 7 + Item 8 + Item 9 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.cs b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.cs new file mode 100644 index 00000000..05adc971 --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.cs @@ -0,0 +1,5 @@ +using Avalonia.Controls; + +namespace StabilityMatrix.Avalonia.Controls.Scroll; + +public class BetterScrollViewer : ScrollViewer { } diff --git a/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs index f08853dd..06d164e9 100644 --- a/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs @@ -1,7 +1,7 @@ using Avalonia.Controls.Primitives; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; -public class SeedCard : TemplatedControl -{ -} +[Transient] +public class SeedCard : TemplatedControl { } diff --git a/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs index dd748696..726cf414 100644 --- a/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs @@ -1,3 +1,6 @@ -namespace StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Core.Attributes; +namespace StabilityMatrix.Avalonia.Controls; + +[Transient] public class SelectImageCard : DropTargetTemplatedControlBase { } diff --git a/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml new file mode 100644 index 00000000..c94276b1 --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.cs b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.cs new file mode 100644 index 00000000..7878bf8b --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.cs @@ -0,0 +1,56 @@ +using System; +using AsyncImageLoader; +using Avalonia; +using Avalonia.Controls; +using Avalonia.Controls.Primitives; + +namespace StabilityMatrix.Avalonia.Controls.SelectableImageCard; + +public class SelectableImageButton : Button +{ + public static readonly StyledProperty IsSelectedProperty = + ToggleButton.IsCheckedProperty.AddOwner(); + + public static readonly StyledProperty SourceProperty = + AdvancedImage.SourceProperty.AddOwner(); + + public static readonly StyledProperty ImageWidthProperty = AvaloniaProperty.Register< + SelectableImageButton, + double + >("ImageWidth", 300); + + public static readonly StyledProperty ImageHeightProperty = AvaloniaProperty.Register< + SelectableImageButton, + double + >("ImageHeight", 300); + + static SelectableImageButton() + { + AffectsRender(ImageWidthProperty, ImageHeightProperty); + AffectsArrange(ImageWidthProperty, ImageHeightProperty); + } + + public double ImageHeight + { + get => GetValue(ImageHeightProperty); + set => SetValue(ImageHeightProperty, value); + } + + public double ImageWidth + { + get => GetValue(ImageWidthProperty); + set => SetValue(ImageWidthProperty, value); + } + + public bool? IsSelected + { + get => GetValue(IsSelectedProperty); + set => SetValue(IsSelectedProperty, value); + } + + public string? Source + { + get => GetValue(SourceProperty); + set => SetValue(SourceProperty, value); + } +} diff --git a/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs index 6fd22370..4c2a250d 100644 --- a/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs @@ -1,7 +1,7 @@ -using Avalonia; -using Avalonia.Controls; -using Avalonia.Controls.Primitives; +using Avalonia.Controls.Primitives; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; +[Transient] public class SharpenCard : TemplatedControl { } diff --git a/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs index c8374349..2341a96d 100644 --- a/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs @@ -1,14 +1,18 @@ using System.Diagnostics.CodeAnalysis; using Avalonia; using Avalonia.Controls.Primitives; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] +[Transient] public class StackCard : TemplatedControl { - public static readonly StyledProperty SpacingProperty = AvaloniaProperty.Register( - "Spacing", 8); + public static readonly StyledProperty SpacingProperty = AvaloniaProperty.Register< + StackCard, + int + >("Spacing", 8); public int Spacing { diff --git a/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs b/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs index 3fb2304b..38208891 100644 --- a/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs @@ -1,8 +1,10 @@ using Avalonia; using Avalonia.Controls.Primitives; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; +[Transient] public class StackExpander : TemplatedControl { public static readonly StyledProperty SpacingProperty = AvaloniaProperty.Register< diff --git a/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs index 083d09fa..7e2e2327 100644 --- a/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs @@ -1,13 +1,14 @@ -using System; -using AsyncAwaitBestPractices; +using AsyncAwaitBestPractices; using Avalonia.Controls; using Avalonia.Controls.Primitives; using FluentAvalonia.UI.Controls; using StabilityMatrix.Avalonia.ViewModels.Inference; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models.Api.Comfy; namespace StabilityMatrix.Avalonia.Controls; +[Transient] public class UpscalerCard : TemplatedControl { /// diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs index 43d8e751..4d785605 100644 --- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs +++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs @@ -1,14 +1,16 @@ using System; using System.Collections.Generic; -using System.Collections.Immutable; using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Http; using System.Text; using AvaloniaEdit.Utils; +using DynamicData; +using DynamicData.Binding; using Microsoft.Extensions.DependencyInjection; -using StabilityMatrix.Avalonia.Controls; +using NSubstitute; +using NSubstitute.ReturnsExtensions; using StabilityMatrix.Avalonia.Controls.CodeCompletion; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.TagCompletion; @@ -20,6 +22,7 @@ using StabilityMatrix.Avalonia.ViewModels.CheckpointManager; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Progress; using StabilityMatrix.Avalonia.ViewModels.Inference; +using StabilityMatrix.Avalonia.ViewModels.OutputsPage; using StabilityMatrix.Avalonia.ViewModels.Settings; using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Database; @@ -29,7 +32,9 @@ using StabilityMatrix.Core.Helper.Factory; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api; using StabilityMatrix.Core.Models.Api.Comfy; +using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.PackageModification; +using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; @@ -52,7 +57,7 @@ public static class DesignData if (isInitialized) throw new InvalidOperationException("DesignData is already initialized."); - var services = new ServiceCollection(); + var services = App.ConfigureServices(); var activePackageId = Guid.NewGuid(); services.AddSingleton( @@ -106,18 +111,18 @@ public static class DesignData // Mock services services - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) .AddSingleton() - .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton(); + .AddSingleton(); // Placeholder services that nobody should need during design time services @@ -128,12 +133,6 @@ public static class DesignData .AddSingleton(_ => null!) .AddSingleton(_ => null!); - // Using some default service implementations from App - App.ConfigurePackages(services); - App.ConfigurePageViewModels(services); - App.ConfigureDialogViewModels(services); - App.ConfigureViews(services); - // Override Launch page with mock services.Remove(ServiceDescriptor.Singleton()); services.AddSingleton(); @@ -172,12 +171,61 @@ public static class DesignData LaunchOptionsViewModel.UpdateFilterCards(); InstallerViewModel = Services.GetRequiredService(); - InstallerViewModel.AvailablePackages = packageFactory - .GetAllAvailablePackages() - .ToImmutableArray(); + InstallerViewModel.AvailablePackages = new ObservableCollectionExtended( + packageFactory.GetAllAvailablePackages() + ); InstallerViewModel.SelectedPackage = InstallerViewModel.AvailablePackages[0]; InstallerViewModel.ReleaseNotes = "## Release Notes\nThis is a test release note."; + ObservableCacheEx.AddOrUpdate( + CheckpointsPageViewModel.CheckpointFoldersCache, + new CheckpointFolder[] + { + new(settingsManager, downloadService, modelFinder, notificationService) + { + DirectoryPath = "Models/StableDiffusion", + DisplayedCheckpointFiles = new ObservableCollectionExtended() + { + new() + { + FilePath = "~/Models/StableDiffusion/electricity-light.safetensors", + Title = "Auroral Background", + PreviewImagePath = + "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/" + + "78fd2a0a-42b6-42b0-9815-81cb11bb3d05/00009-2423234823.jpeg", + ConnectedModel = new ConnectedModelInfo + { + VersionName = "Lightning Auroral", + BaseModel = "SD 1.5", + ModelName = "Auroral Background", + ModelType = CivitModelType.Model, + FileMetadata = new CivitFileMetadata + { + Format = CivitModelFormat.SafeTensor, + Fp = CivitModelFpType.fp16, + Size = CivitModelSize.pruned, + } + } + }, + new() + { + FilePath = "~/Models/Lora/model.safetensors", + Title = "Some model" + }, + }, + }, + new(settingsManager, downloadService, modelFinder, notificationService) + { + Title = "Lora", + DirectoryPath = "Packages/Lora", + DisplayedCheckpointFiles = new ObservableCollectionExtended + { + new() { FilePath = "~/Models/Lora/lora_v2.pt", Title = "Best Lora v2", } + } + } + } + ); + /*// Checkpoints page CheckpointsPageViewModel.CheckpointFolders = new CheckpointFolder[] @@ -336,6 +384,26 @@ public static class DesignData public static LaunchPageViewModel LaunchPageViewModel => Services.GetRequiredService(); + public static OutputsPageViewModel OutputsPageViewModel + { + get + { + var vm = Services.GetRequiredService(); + vm.Outputs = new ObservableCollectionExtended + { + new( + new LocalImageFile + { + AbsolutePath = + "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/78fd2a0a-42b6-42b0-9815-81cb11bb3d05/00009-2423234823.jpeg", + ImageType = LocalImageFileType.TextToImage + } + ) + }; + return vm; + } + } + public static PackageManagerViewModel PackageManagerViewModel { get @@ -465,6 +533,15 @@ The gallery images are often inpainted, but you will get something very similar viewModel.EnvVars = new ObservableCollection { new("UWU", "TRUE"), }; }); + public static PythonPackagesViewModel PythonPackagesViewModel => + DialogFactory.Get(vm => + { + vm.AddPackages( + new PipPackageInfo("pip", "1.0.0"), + new PipPackageInfo("torch", "2.1.0+cu121") + ); + }); + public static InferenceTextToImageViewModel InferenceTextToImageViewModel => DialogFactory.Get(vm => { @@ -603,8 +680,8 @@ The gallery images are often inpainted, but you will get something very similar get { var list = new CompletionList { IsFiltering = true }; - list.CompletionData.AddRange(SampleCompletionData); - list.FilteredCompletionData.AddRange(list.CompletionData); + ExtensionMethods.AddRange(list.CompletionData, SampleCompletionData); + ExtensionMethods.AddRange(list.FilteredCompletionData, list.CompletionData); list.SelectItem("te", true); return list; } diff --git a/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs b/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs deleted file mode 100644 index 6bffe945..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using StabilityMatrix.Core.Api; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockApiFactory : IApiFactory -{ - public T CreateRefitClient(Uri baseAddress) - { - throw new NotImplementedException(); - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs b/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs deleted file mode 100644 index 79851cd7..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs +++ /dev/null @@ -1,18 +0,0 @@ -using System; -using StabilityMatrix.Avalonia.Services; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockDiscordRichPresenceService : IDiscordRichPresenceService -{ - /// - public void Dispose() - { - GC.SuppressFinalize(this); - } - - /// - public void UpdateState() - { - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs deleted file mode 100644 index b1e08d8b..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs +++ /dev/null @@ -1,50 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using StabilityMatrix.Core.Models.Progress; -using StabilityMatrix.Core.Services; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockDownloadService : IDownloadService -{ - public Task DownloadToFileAsync( - string downloadUrl, - string downloadPath, - IProgress? progress = null, - string? httpClientName = null, - CancellationToken cancellationToken = default - ) - { - return Task.CompletedTask; - } - - /// - public Task ResumeDownloadToFileAsync( - string downloadUrl, - string downloadPath, - long existingFileSize, - IProgress? progress = null, - string? httpClientName = null, - CancellationToken cancellationToken = default - ) - { - return Task.CompletedTask; - } - - /// - public Task GetFileSizeAsync( - string downloadUrl, - string? httpClientName = null, - CancellationToken cancellationToken = default - ) - { - return Task.FromResult(0L); - } - - public Task GetImageStreamFromUrl(string url) - { - return Task.FromResult(new MemoryStream(new byte[24]) as Stream)!; - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs b/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs deleted file mode 100644 index 452af8c4..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Net.Http; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockHttpClientFactory : IHttpClientFactory -{ - public HttpClient CreateClient(string name) - { - throw new NotImplementedException(); - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs b/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs index d5ee1650..39474975 100644 --- a/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs +++ b/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs @@ -1,8 +1,8 @@ -using System.Collections.Generic; using System.Threading.Tasks; +using DynamicData; +using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Database; -using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.DesignData; @@ -10,47 +10,50 @@ namespace StabilityMatrix.Avalonia.DesignData; public class MockImageIndexService : IImageIndexService { /// - public IndexCollection InferenceImages { get; } = - new IndexCollection(null!, file => file.RelativePath) + public IndexCollection InferenceImages { get; } + + public MockImageIndexService() + { + InferenceImages = new IndexCollection( + this, + file => file.AbsolutePath + ) { - RelativePath = "inference" + RelativePath = "Inference" }; + } /// - public Task> GetLocalImagesByPrefix(string pathPrefix) + public Task RefreshIndexForAllCollections() { - return Task.FromResult( - (IReadOnlyList) - new LocalImageFile[] - { - new() - { - RelativePath = - "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/4a7e00a7-6f18-42d4-87c0-10e792df2640/width=1152", - }, - new() - { - RelativePath = - "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/a318ac1f-3ad0-48ac-98cc-79126febcc17/width=1024", - }, - new() - { - RelativePath = - "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/16588c94-6595-4be9-8806-d7e6e22d198c/width=1152", - } - } - ); + return RefreshIndex(InferenceImages); } - /// - public Task RefreshIndexForAllCollections() + private static LocalImageFile GetSampleImage(string url) { - return Task.CompletedTask; + return new LocalImageFile + { + AbsolutePath = url, + GenerationParameters = GenerationParameters.GetSample(), + ImageSize = new System.Drawing.Size(1024, 1024) + }; } /// public Task RefreshIndex(IndexCollection indexCollection) { + var toAdd = new[] + { + GetSampleImage( + "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/a318ac1f-3ad0-48ac-98cc-79126febcc17/width=1024" + ), + GetSampleImage( + "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/16588c94-6595-4be9-8806-d7e6e22d198c/width=1152" + ) + }; + + indexCollection.ItemsSource.EditDiff(toAdd); + return Task.CompletedTask; } @@ -59,10 +62,4 @@ public class MockImageIndexService : IImageIndexService { throw new System.NotImplementedException(); } - - /// - public Task RemoveImage(LocalImageFile imageFile) - { - throw new System.NotImplementedException(); - } } diff --git a/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs b/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs index dffaee55..70b928f4 100644 --- a/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; using CommunityToolkit.Mvvm.ComponentModel; +using DynamicData; using DynamicData.Binding; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Services; @@ -48,6 +49,17 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl /// public bool CanUserDisconnect => IsConnected && !IsConnecting; + public MockInferenceClientManager() + { + Models.AddRange( + new[] + { + HybridModelFile.FromRemote("v1-5-pruned-emaonly.safetensors"), + HybridModelFile.FromRemote("artshaper1.safetensors"), + } + ); + } + /// public Task CopyImageToInputAsync( FilePath imageFile, diff --git a/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs b/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs deleted file mode 100644 index a8f2996e..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs +++ /dev/null @@ -1,62 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using LiteDB.Async; -using StabilityMatrix.Core.Database; -using StabilityMatrix.Core.Models.Api; -using StabilityMatrix.Core.Models.Database; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockLiteDbContext : ILiteDbContext -{ - public LiteDatabaseAsync Database => throw new NotImplementedException(); - public ILiteCollectionAsync CivitModels => throw new NotImplementedException(); - public ILiteCollectionAsync CivitModelVersions => - throw new NotImplementedException(); - public ILiteCollectionAsync CivitModelQueryCache => - throw new NotImplementedException(); - public ILiteCollectionAsync LocalModelFiles => - throw new NotImplementedException(); - public ILiteCollectionAsync InferenceProjects => - throw new NotImplementedException(); - public ILiteCollectionAsync LocalImageFiles => - throw new NotImplementedException(); - - public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync( - string hashBlake3 - ) - { - return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null)); - } - - public Task UpsertCivitModelAsync(CivitModel civitModel) - { - return Task.FromResult(true); - } - - public Task UpsertCivitModelAsync(IEnumerable civitModels) - { - return Task.FromResult(true); - } - - public Task UpsertCivitModelQueryCacheEntryAsync(CivitModelQueryCacheEntry entry) - { - return Task.FromResult(true); - } - - public Task GetGithubCacheEntry(string cacheKey) - { - return Task.FromResult(null); - } - - public Task UpsertGithubCacheEntry(GithubCacheEntry cacheEntry) - { - return Task.FromResult(true); - } - - public void Dispose() - { - GC.SuppressFinalize(this); - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs b/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs deleted file mode 100644 index 4fdde7ad..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs +++ /dev/null @@ -1,47 +0,0 @@ -using System; -using System.Threading.Tasks; -using Avalonia; -using Avalonia.Controls.Notifications; -using StabilityMatrix.Avalonia.Services; -using StabilityMatrix.Core.Models; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockNotificationService : INotificationService -{ - public void Initialize(Visual? visual, - NotificationPosition position = NotificationPosition.BottomRight, int maxItems = 3) - { - } - - public void Show(INotification notification) - { - } - - public Task> TryAsync(Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error) - { - return Task.FromResult(new TaskResult(default!)); - } - - public Task> TryAsync(Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error) - { - return Task.FromResult(new TaskResult(true)); - } - - public void Show( - string title, - string message, - NotificationType appearance = NotificationType.Information, - TimeSpan? expiration = null) - { - } - - public void ShowPersistent( - string title, - string message, - NotificationType appearance = NotificationType.Information) - { - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs b/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs deleted file mode 100644 index 35c21160..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System.Threading.Tasks; -using StabilityMatrix.Core.Helper; -using StabilityMatrix.Core.Models.FileInterfaces; -using StabilityMatrix.Core.Models.Packages; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockSharedFolders : ISharedFolders -{ - public void SetupLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory) - { - } - - public Task UpdateLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory) - { - return Task.CompletedTask; - } - - public void RemoveLinksForAllPackages() - { - } - - public void SetupSharedModelFolders() - { - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs deleted file mode 100644 index 4522bd2d..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; -using System.Collections.Generic; -using StabilityMatrix.Core.Models; -using StabilityMatrix.Core.Models.FileInterfaces; -using StabilityMatrix.Core.Services; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockTrackedDownloadService : ITrackedDownloadService -{ - /// - public IEnumerable Downloads => Array.Empty(); - - /// - public event EventHandler? DownloadAdded; - - /// - public TrackedDownload NewDownload(Uri downloadUrl, FilePath downloadPath) - { - throw new NotImplementedException(); - } -} diff --git a/StabilityMatrix.Avalonia/DialogHelper.cs b/StabilityMatrix.Avalonia/DialogHelper.cs index 15397e48..d99bc8dc 100644 --- a/StabilityMatrix.Avalonia/DialogHelper.cs +++ b/StabilityMatrix.Avalonia/DialogHelper.cs @@ -9,16 +9,15 @@ using System.Text.Json; using Avalonia; using Avalonia.Controls; using Avalonia.Data; +using Avalonia.Layout; using Avalonia.LogicalTree; using Avalonia.Media; using Avalonia.Threading; using AvaloniaEdit; using AvaloniaEdit.TextMate; -using CommunityToolkit.Mvvm.ComponentModel.__Internals; using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; using Markdown.Avalonia; -using Markdown.Avalonia.SyntaxHigh.Extensions; using NLog; using Refit; using StabilityMatrix.Avalonia.Controls; @@ -26,7 +25,6 @@ using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; -using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Services; using TextMateSharp.Grammars; using Process = FuzzySharp.Process; @@ -96,6 +94,18 @@ public static class DialogHelper Watermark = field.Watermark, DataContext = field, }; + + if (!string.IsNullOrEmpty(field.InnerLeftText)) + { + textBox.InnerLeftContent = new TextBlock() + { + Text = field.InnerLeftText, + Foreground = Brushes.Gray, + VerticalAlignment = VerticalAlignment.Center, + Margin = new Thickness(8, 0, -4, 0) + }; + } + stackPanel.Children.Add(textBox); // When IsValid property changes, update invalid count and primary button @@ -427,7 +437,7 @@ public static class DialogHelper AllowScrollBelowDocument = false } }; - TextEditorConfigs.ConfigForPrompt(textEditor); + TextEditorConfigs.Configure(textEditor, TextEditorPreset.Prompt); textEditor.Document.Text = errorLineFormatted; textEditor.TextArea.Caret.Offset = textEditor.Document.Lines[0].EndOffset; @@ -518,15 +528,6 @@ public static class DialogHelper XamlRoot = App.VisualRoot }; } - - /// - /// Creates a connection help dialog. - /// - public static BetterContentDialog CreateConnectionHelpDialog() - { - // TODO - return new BetterContentDialog(); - } } // Text fields @@ -541,6 +542,9 @@ public sealed class TextBoxField : INotifyPropertyChanged // Watermark text public string Watermark { get; init; } = string.Empty; + // Inner left value + public string? InnerLeftText { get; init; } + /// /// Validation action on text changes. Throw exception if invalid. /// diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs index a6303f21..81691e93 100644 --- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs +++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs @@ -184,7 +184,7 @@ public static class ComfyNodeBuilderExtensions var checkpointLoader = builder.Nodes.AddNamedNode( ComfyNodeBuilder.CheckpointLoaderSimple( "Refiner_CheckpointLoader", - modelCardViewModel.SelectedRefiner?.FileName + modelCardViewModel.SelectedRefiner?.RelativePath ?? throw new NullReferenceException("Model not selected") ) ); @@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions builder.Connections.ImageSize = builder.Connections.LatentSize; } - var saveImage = builder.Nodes.AddNamedNode( + var previewImage = builder.Nodes.AddNamedNode( new NamedComfyNode("SaveImage") { - ClassType = "SaveImage", - Inputs = new Dictionary - { - ["filename_prefix"] = "Inference/TextToImage", - ["images"] = builder.Connections.Image - } + ClassType = "PreviewImage", + Inputs = new Dictionary { ["images"] = builder.Connections.Image } } ); - builder.Connections.OutputNodes.Add(saveImage); + builder.Connections.OutputNodes.Add(previewImage); - return saveImage.Name; + return previewImage.Name; } } diff --git a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs index 28c215d6..a090c6a2 100644 --- a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs +++ b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs @@ -13,50 +13,57 @@ public static class ImageProcessor /// public static (int rows, int columns) GetGridDimensionsFromImageCount(int count) { - if (count <= 1) return (1, 1); - if (count == 2) return (1, 2); - + if (count <= 1) + return (1, 1); + if (count == 2) + return (1, 2); + // Prefer one extra row over one extra column, // the row count will be the floor of the square root // and the column count will be floor of count / rows - var rows = (int) Math.Floor(Math.Sqrt(count)); - var columns = (int) Math.Floor((double) count / rows); + var rows = (int)Math.Floor(Math.Sqrt(count)); + var columns = (int)Math.Floor((double)count / rows); return (rows, columns); } - - public static SKImage CreateImageGrid( - IReadOnlyList images, - int spacing = 0) + + public static SKImage CreateImageGrid(IReadOnlyList images, int spacing = 0) { + if (images.Count == 0) + throw new ArgumentException("Must have at least one image"); + var (rows, columns) = GetGridDimensionsFromImageCount(images.Count); var singleWidth = images[0].Width; var singleHeight = images[0].Height; - + // Make output image using var output = new SKBitmap( - singleWidth * columns + spacing * (columns - 1), - singleHeight * rows + spacing * (rows - 1)); - + singleWidth * columns + spacing * (columns - 1), + singleHeight * rows + spacing * (rows - 1) + ); + // Draw images using var canvas = new SKCanvas(output); - - foreach (var (row, column) in - Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))) + + foreach ( + var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns)) + ) { // Stop if we have drawn all images var index = row * columns + column; - if (index >= images.Count) break; - + if (index >= images.Count) + break; + // Get image var image = images[index]; - + // Draw image var destination = new SKRect( singleWidth * column + spacing * column, singleHeight * row + spacing * row, singleWidth * column + spacing * column + image.Width, - singleHeight * row + spacing * row + image.Height); + singleHeight * row + spacing * row + image.Height + ); canvas.DrawImage(image, destination); } diff --git a/StabilityMatrix.Avalonia/Helpers/ImageSearcher.cs b/StabilityMatrix.Avalonia/Helpers/ImageSearcher.cs new file mode 100644 index 00000000..b5d280fe --- /dev/null +++ b/StabilityMatrix.Avalonia/Helpers/ImageSearcher.cs @@ -0,0 +1,95 @@ +using System; +using FuzzySharp; +using FuzzySharp.PreProcess; +using StabilityMatrix.Core.Models.Database; + +namespace StabilityMatrix.Avalonia.Helpers; + +public class ImageSearcher +{ + public int MinimumFuzzScore { get; init; } = 80; + + public ImageSearchOptions SearchOptions { get; init; } = ImageSearchOptions.All; + + public Func GetPredicate(string? searchQuery) + { + if (string.IsNullOrEmpty(searchQuery)) + { + return _ => true; + } + + return file => + { + if (file.FileName.Contains(searchQuery, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + if ( + SearchOptions.HasFlag(ImageSearchOptions.FileName) + && Fuzz.WeightedRatio(searchQuery, file.FileName, PreprocessMode.Full) + > MinimumFuzzScore + ) + { + return true; + } + + // Generation params + if (file.GenerationParameters is { } parameters) + { + if ( + SearchOptions.HasFlag(ImageSearchOptions.PositivePrompt) + && ( + parameters.PositivePrompt?.Contains( + searchQuery, + StringComparison.OrdinalIgnoreCase + ) ?? false + ) + || SearchOptions.HasFlag(ImageSearchOptions.NegativePrompt) + && ( + parameters.NegativePrompt?.Contains( + searchQuery, + StringComparison.OrdinalIgnoreCase + ) ?? false + ) + || SearchOptions.HasFlag(ImageSearchOptions.Seed) + && parameters.Seed + .ToString() + .StartsWith(searchQuery, StringComparison.OrdinalIgnoreCase) + || SearchOptions.HasFlag(ImageSearchOptions.Sampler) + && ( + parameters.Sampler?.StartsWith( + searchQuery, + StringComparison.OrdinalIgnoreCase + ) ?? false + ) + || SearchOptions.HasFlag(ImageSearchOptions.ModelName) + && ( + parameters.ModelName?.StartsWith( + searchQuery, + StringComparison.OrdinalIgnoreCase + ) ?? false + ) + ) + { + return true; + } + } + + return false; + }; + } + + [Flags] + public enum ImageSearchOptions + { + None = 0, + FileName = 1 << 0, + PositivePrompt = 1 << 1, + NegativePrompt = 1 << 2, + Seed = 1 << 3, + Sampler = 1 << 4, + ModelName = 1 << 5, + All = int.MaxValue + } +} diff --git a/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs b/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs index 086d0785..1d729cfe 100644 --- a/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs @@ -3,7 +3,6 @@ using System.IO; using System.Linq; using System.Text; using System.Text.Json; -using Avalonia; using Force.Crc32; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Core.Models; @@ -16,6 +15,17 @@ public static class PngDataHelper private static readonly byte[] Text = { 0x74, 0x45, 0x58, 0x74 }; private static readonly byte[] Iend = { 0x49, 0x45, 0x4E, 0x44 }; + public static byte[] AddMetadata( + Stream inputStream, + GenerationParameters generationParameters, + InferenceProjectDocument projectDocument + ) + { + using var ms = new MemoryStream(); + inputStream.CopyTo(ms); + return AddMetadata(ms.ToArray(), generationParameters, projectDocument); + } + public static byte[] AddMetadata( byte[] inputImage, GenerationParameters generationParameters, diff --git a/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs b/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs index 1afa66e9..414fc1a5 100644 --- a/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs +++ b/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs @@ -1,24 +1,21 @@ using System.Collections.Generic; -using System.Data.Common; -using System.Globalization; using System.IO; using System.Threading.Tasks; using StabilityMatrix.Avalonia.Models.TagCompletion; using Sylvan.Data.Csv; using Sylvan; -using Sylvan.Data; namespace StabilityMatrix.Avalonia.Helpers; public class TagCsvParser { private readonly Stream stream; - + public TagCsvParser(Stream stream) { this.stream = stream; } - + public async IAsyncEnumerable ParseAsync() { var pool = new StringPool(); @@ -27,10 +24,10 @@ public class TagCsvParser StringFactory = pool.GetString, HasHeaders = false, }; - + using var textReader = new StreamReader(stream); await using var dataReader = await CsvDataReader.CreateAsync(textReader, options); - + while (await dataReader.ReadAsync()) { var entry = new TagCsvEntry @@ -42,7 +39,7 @@ public class TagCsvParser }; yield return entry; } - + /*var dataBinderOptions = new DataBinderOptions { BindingMode = DataBindingMode.Any @@ -54,17 +51,17 @@ public class TagCsvParser public async Task> GetDictionaryAsync() { var dict = new Dictionary(); - + await foreach (var entry in ParseAsync()) { if (entry.Name is null || entry.Type is null) { continue; } - + dict.Add(entry.Name, entry); } - + return dict; } } diff --git a/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs b/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs index 21139027..83a744eb 100644 --- a/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs +++ b/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs @@ -1,5 +1,5 @@ -using System.IO; -using System.Reflection; +using System; +using System.IO; using Avalonia.Media; using AvaloniaEdit; using AvaloniaEdit.TextMate; @@ -18,13 +18,22 @@ public static class TextEditorConfigs { public static void Configure(TextEditor editor, TextEditorPreset preset) { - if (preset == TextEditorPreset.Prompt) + switch (preset) { - ConfigForPrompt(editor); + case TextEditorPreset.Prompt: + ConfigForPrompt(editor); + break; + case TextEditorPreset.Console: + ConfigForConsole(editor); + break; + case TextEditorPreset.None: + break; + default: + throw new ArgumentOutOfRangeException(nameof(preset), preset, null); } } - public static void ConfigForPrompt(TextEditor editor) + private static void ConfigForPrompt(TextEditor editor) { const ThemeName themeName = ThemeName.DimmedMonokai; var registryOptions = new RegistryOptions(themeName); @@ -58,6 +67,25 @@ public static class TextEditorConfigs installation.SetTheme(theme); } + private static void ConfigForConsole(TextEditor editor) + { + var registryOptions = new RegistryOptions(ThemeName.DarkPlus); + + // Config hyperlinks + editor.TextArea.Options.EnableHyperlinks = true; + editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false; + editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral; + + var textMate = editor.InstallTextMate(registryOptions); + var scope = registryOptions.GetScopeByLanguageId("log"); + + if (scope is null) + throw new InvalidOperationException("Scope is null"); + + textMate.SetGrammar(scope); + textMate.SetTheme(registryOptions.LoadTheme(ThemeName.DarkPlus)); + } + private static IRawTheme GetThemeFromStream(Stream stream) { using var reader = new StreamReader(stream); diff --git a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs index 07a38fb9..04cb8ffd 100644 --- a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs @@ -126,7 +126,11 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper } } - public async Task RunGit(string? workingDirectory = null, params string[] args) + public async Task RunGit( + string? workingDirectory = null, + Action? onProcessOutput = null, + params string[] args + ) { var command = args.Length == 0 ? "git" : "git " + string.Join(" ", args.Select(ProcessRunner.Quote)); @@ -229,6 +233,13 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper throw new NotImplementedException(); } + [UnsupportedOSPlatform("Linux")] + [UnsupportedOSPlatform("macOS")] + public Task InstallTkinterIfNecessary(IProgress? progress = null) + { + throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("Linux")] [UnsupportedOSPlatform("macOS")] public Task InstallVcRedistIfNecessary(IProgress? progress = null) diff --git a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs index e5ee642f..8df65b94 100644 --- a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs +++ b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs @@ -17,70 +17,87 @@ namespace StabilityMatrix.Avalonia.Helpers; public class WindowsPrerequisiteHelper : IPrerequisiteHelper { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - + private readonly IGitHubClient gitHubClient; private readonly IDownloadService downloadService; private readonly ISettingsManager settingsManager; - + private const string VcRedistDownloadUrl = "https://aka.ms/vs/16/release/vc_redist.x64.exe"; - + private const string TkinterDownloadUrl = + "https://cdn.lykos.ai/tkinter-cpython-embedded-3.10.11-win-x64.zip"; + private string HomeDir => settingsManager.LibraryDir; - + private string VcRedistDownloadPath => Path.Combine(HomeDir, "vcredist.x64.exe"); private string AssetsDir => Path.Combine(HomeDir, "Assets"); private string SevenZipPath => Path.Combine(AssetsDir, "7za.exe"); - + private string PythonDownloadPath => Path.Combine(AssetsDir, "python-3.10.11-embed-amd64.zip"); private string PythonDir => Path.Combine(AssetsDir, "Python310"); private string PythonDllPath => Path.Combine(PythonDir, "python310.dll"); private string PythonLibraryZipPath => Path.Combine(PythonDir, "python310.zip"); private string GetPipPath => Path.Combine(PythonDir, "get-pip.pyc"); + // Temporary directory to extract venv to during python install private string VenvTempDir => Path.Combine(PythonDir, "venv"); - + private string PortableGitInstallDir => Path.Combine(HomeDir, "PortableGit"); private string PortableGitDownloadPath => Path.Combine(HomeDir, "PortableGit.7z.exe"); private string GitExePath => Path.Combine(PortableGitInstallDir, "bin", "git.exe"); + private string TkinterZipPath => Path.Combine(AssetsDir, "tkinter.zip"); + private string TkinterExtractPath => PythonDir; + private string TkinterExistsPath => Path.Combine(PythonDir, "tkinter"); public string GitBinPath => Path.Combine(PortableGitInstallDir, "bin"); - + public bool IsPythonInstalled => File.Exists(PythonDllPath); public WindowsPrerequisiteHelper( IGitHubClient gitHubClient, - IDownloadService downloadService, - ISettingsManager settingsManager) + IDownloadService downloadService, + ISettingsManager settingsManager + ) { this.gitHubClient = gitHubClient; this.downloadService = downloadService; this.settingsManager = settingsManager; } - public async Task RunGit(string? workingDirectory = null, params string[] args) + public async Task RunGit( + string? workingDirectory = null, + Action? onProcessOutput = null, + params string[] args + ) { - var process = ProcessRunner.StartAnsiProcess(GitExePath, args, + var process = ProcessRunner.StartAnsiProcess( + GitExePath, + args, workingDirectory: workingDirectory, environmentVariables: new Dictionary { - {"PATH", Compat.GetEnvPathWithExtensions(GitBinPath)} - }); - + { "PATH", Compat.GetEnvPathWithExtensions(GitBinPath) } + }, + outputDataReceived: onProcessOutput + ); + await ProcessRunner.WaitForExitConditionAsync(process); } public async Task GetGitOutput(string? workingDirectory = null, params string[] args) { var process = await ProcessRunner.GetProcessOutputAsync( - GitExePath, string.Join(" ", args), + GitExePath, + string.Join(" ", args), workingDirectory: workingDirectory, environmentVariables: new Dictionary { - {"PATH", Compat.GetEnvPathWithExtensions(GitBinPath)} - }); - + { "PATH", Compat.GetEnvPathWithExtensions(GitBinPath) } + } + ); + return process; } - + public async Task InstallAllIfNecessary(IProgress? progress = null) { await InstallVcRedistIfNecessary(progress); @@ -97,16 +114,20 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper (Assets.SevenZipExecutable, AssetsDir), (Assets.SevenZipLicense, AssetsDir), }; - - progress?.Report(new ProgressReport(0, message: "Unpacking resources", isIndeterminate: true)); - + + progress?.Report( + new ProgressReport(0, message: "Unpacking resources", isIndeterminate: true) + ); + Directory.CreateDirectory(AssetsDir); foreach (var (asset, extractDir) in assets) { await asset.ExtractToDir(extractDir); } - - progress?.Report(new ProgressReport(1, message: "Unpacking resources", isIndeterminate: false)); + + progress?.Report( + new ProgressReport(1, message: "Unpacking resources", isIndeterminate: false) + ); } public async Task InstallPythonIfNecessary(IProgress? progress = null) @@ -120,7 +141,7 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper Logger.Info("Python not found at {PythonDllPath}, downloading...", PythonDllPath); Directory.CreateDirectory(AssetsDir); - + // Delete existing python zip if it exists if (File.Exists(PythonLibraryZipPath)) { @@ -130,44 +151,45 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper var remote = Assets.PythonDownloadUrl; var url = remote.Url.ToString(); Logger.Info($"Downloading Python from {url} to {PythonLibraryZipPath}"); - + // Cleanup to remove zip if download fails try { // Download python zip await downloadService.DownloadToFileAsync(url, PythonDownloadPath, progress: progress); - + // Verify python hash var downloadHash = await FileHash.GetSha256Async(PythonDownloadPath, progress); if (downloadHash != remote.HashSha256) { var fileExists = File.Exists(PythonDownloadPath); var fileSize = new FileInfo(PythonDownloadPath).Length; - var msg = $"Python download hash mismatch: {downloadHash} != {remote.HashSha256} " + - $"(file exists: {fileExists}, size: {fileSize})"; + var msg = + $"Python download hash mismatch: {downloadHash} != {remote.HashSha256} " + + $"(file exists: {fileExists}, size: {fileSize})"; throw new Exception(msg); } - + progress?.Report(new ProgressReport(progress: 1f, message: "Python download complete")); - + progress?.Report(new ProgressReport(-1, "Installing Python...", isIndeterminate: true)); - + // We also need 7z if it's not already unpacked if (!File.Exists(SevenZipPath)) { await Assets.SevenZipExecutable.ExtractToDir(AssetsDir); await Assets.SevenZipLicense.ExtractToDir(AssetsDir); } - + // Delete existing python dir if (Directory.Exists(PythonDir)) { Directory.Delete(PythonDir, true); } - + // Unzip python await ArchiveHelper.Extract7Z(PythonDownloadPath, PythonDir); - + try { // Extract embedded venv folder @@ -185,7 +207,7 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper await resource.ExtractTo(path); } // Add venv to python's library zip - + await ArchiveHelper.AddToArchive7Z(PythonLibraryZipPath, VenvTempDir); } finally @@ -196,16 +218,19 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper Directory.Delete(VenvTempDir, true); } } - + // Extract get-pip.pyc await Assets.PyScriptGetPip.ExtractToDir(PythonDir); - + // We need to uncomment the #import site line in python310._pth for pip to work var pythonPthPath = Path.Combine(PythonDir, "python310._pth"); var pythonPthContent = await File.ReadAllTextAsync(pythonPthPath); pythonPthContent = pythonPthContent.Replace("#import site", "import site"); await File.WriteAllTextAsync(pythonPthPath, pythonPthContent); - + + // Install TKinter + await InstallTkinterIfNecessary(progress); + progress?.Report(new ProgressReport(1f, "Python install complete")); } finally @@ -218,6 +243,39 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper } } + [SupportedOSPlatform("windows")] + public async Task InstallTkinterIfNecessary(IProgress? progress = null) + { + if (!Directory.Exists(TkinterExistsPath)) + { + Logger.Info("Downloading Tkinter"); + await downloadService.DownloadToFileAsync( + TkinterDownloadUrl, + TkinterZipPath, + progress: progress + ); + progress?.Report( + new ProgressReport( + progress: 1f, + message: "Tkinter download complete", + type: ProgressType.Download + ) + ); + + await ArchiveHelper.Extract(TkinterZipPath, TkinterExtractPath, progress); + + File.Delete(TkinterZipPath); + } + + progress?.Report( + new ProgressReport( + progress: 1f, + message: "Tkinter install complete", + type: ProgressType.Generic + ) + ); + } + public async Task InstallGitIfNecessary(IProgress? progress = null) { if (File.Exists(GitExePath)) @@ -225,7 +283,7 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper Logger.Debug("Git already installed at {GitExePath}", GitExePath); return; } - + Logger.Info("Git not found at {GitExePath}, downloading...", GitExePath); var portableGitUrl = @@ -233,7 +291,11 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper if (!File.Exists(PortableGitDownloadPath)) { - await downloadService.DownloadToFileAsync(portableGitUrl, PortableGitDownloadPath, progress: progress); + await downloadService.DownloadToFileAsync( + portableGitUrl, + PortableGitDownloadPath, + progress: progress + ); progress?.Report(new ProgressReport(progress: 1f, message: "Git download complete")); } @@ -245,7 +307,9 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper { var registry = Registry.LocalMachine; var key = registry.OpenSubKey( - @"SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\X64", false); + @"SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\X64", + false + ); if (key != null) { var buildId = Convert.ToUInt32(key.GetValue("Bld")); @@ -254,20 +318,44 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper return; } } - + Logger.Info("Downloading VC Redist"); - await downloadService.DownloadToFileAsync(VcRedistDownloadUrl, VcRedistDownloadPath, progress: progress); - progress?.Report(new ProgressReport(progress: 1f, message: "Visual C++ download complete", - type: ProgressType.Download)); - + await downloadService.DownloadToFileAsync( + VcRedistDownloadUrl, + VcRedistDownloadPath, + progress: progress + ); + progress?.Report( + new ProgressReport( + progress: 1f, + message: "Visual C++ download complete", + type: ProgressType.Download + ) + ); + Logger.Info("Installing VC Redist"); - progress?.Report(new ProgressReport(progress: 0.5f, isIndeterminate: true, type: ProgressType.Generic, message: "Installing prerequisites...")); - var process = ProcessRunner.StartAnsiProcess(VcRedistDownloadPath, "/install /quiet /norestart"); + progress?.Report( + new ProgressReport( + progress: 0.5f, + isIndeterminate: true, + type: ProgressType.Generic, + message: "Installing prerequisites..." + ) + ); + var process = ProcessRunner.StartAnsiProcess( + VcRedistDownloadPath, + "/install /quiet /norestart" + ); await process.WaitForExitAsync(); - progress?.Report(new ProgressReport(progress: 1f, message: "Visual C++ install complete", - type: ProgressType.Generic)); - + progress?.Report( + new ProgressReport( + progress: 1f, + message: "Visual C++ install complete", + type: ProgressType.Generic + ) + ); + File.Delete(VcRedistDownloadPath); } @@ -286,5 +374,4 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper File.Delete(PortableGitDownloadPath); } - } diff --git a/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs b/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs index 73fc62e0..1f6f11e8 100644 --- a/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs +++ b/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs @@ -113,6 +113,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Clear Selection. + /// + public static string Action_ClearSelection { + get { + return ResourceManager.GetString("Action_ClearSelection", resourceCulture); + } + } + /// /// Looks up a localized string similar to Close. /// @@ -131,6 +140,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Consolidate. + /// + public static string Action_Consolidate { + get { + return ResourceManager.GetString("Action_Consolidate", resourceCulture); + } + } + /// /// Looks up a localized string similar to Continue. /// @@ -140,6 +158,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Copy. + /// + public static string Action_Copy { + get { + return ResourceManager.GetString("Action_Copy", resourceCulture); + } + } + /// /// Looks up a localized string similar to Delete. /// @@ -149,6 +176,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Downgrade. + /// + public static string Action_Downgrade { + get { + return ResourceManager.GetString("Action_Downgrade", resourceCulture); + } + } + /// /// Looks up a localized string similar to Edit. /// @@ -248,6 +284,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Open in Image Viewer. + /// + public static string Action_OpenInViewer { + get { + return ResourceManager.GetString("Action_OpenInViewer", resourceCulture); + } + } + /// /// Looks up a localized string similar to Open on CivitAI. /// @@ -284,6 +329,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Refresh. + /// + public static string Action_Refresh { + get { + return ResourceManager.GetString("Action_Refresh", resourceCulture); + } + } + /// /// Looks up a localized string similar to Relaunch. /// @@ -383,6 +437,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Select All. + /// + public static string Action_SelectAll { + get { + return ResourceManager.GetString("Action_SelectAll", resourceCulture); + } + } + /// /// Looks up a localized string similar to Select Directory. /// @@ -410,6 +473,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Send to Inference. + /// + public static string Action_SendToInference { + get { + return ResourceManager.GetString("Action_SendToInference", resourceCulture); + } + } + /// /// Looks up a localized string similar to Show in Explorer. /// @@ -446,6 +518,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Upgrade. + /// + public static string Action_Upgrade { + get { + return ResourceManager.GetString("Action_Upgrade", resourceCulture); + } + } + /// /// Looks up a localized string similar to Yes. /// @@ -509,6 +590,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Are you sure?. + /// + public static string Label_AreYouSure { + get { + return ResourceManager.GetString("Label_AreYouSure", resourceCulture); + } + } + /// /// Looks up a localized string similar to Automatically scroll to end of console output. /// @@ -662,6 +752,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to This will move all generated images from the selected packages to the Consolidated directory of the shared outputs folder. This action cannot be undone.. + /// + public static string Label_ConsolidateExplanation { + get { + return ResourceManager.GetString("Label_ConsolidateExplanation", resourceCulture); + } + } + /// /// Looks up a localized string similar to Current directory:. /// @@ -888,7 +987,16 @@ namespace StabilityMatrix.Avalonia.Languages { } /// - /// Looks up a localized string similar to Import as Connected. + /// Looks up a localized string similar to Image to Image. + /// + public static string Label_ImageToImage { + get { + return ResourceManager.GetString("Label_ImageToImage", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Import with Metadata. /// public static string Label_ImportAsConnected { get { @@ -932,6 +1040,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Inpainting. + /// + public static string Label_Inpainting { + get { + return ResourceManager.GetString("Label_Inpainting", resourceCulture); + } + } + /// /// Looks up a localized string similar to Input. /// @@ -1148,6 +1265,24 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to {0} images selected. + /// + public static string Label_NumImagesSelected { + get { + return ResourceManager.GetString("Label_NumImagesSelected", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to 1 image selected. + /// + public static string Label_OneImageSelected { + get { + return ResourceManager.GetString("Label_OneImageSelected", resourceCulture); + } + } + /// /// Looks up a localized string similar to Only available on Windows. /// @@ -1157,6 +1292,33 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Output Folder. + /// + public static string Label_OutputFolder { + get { + return ResourceManager.GetString("Label_OutputFolder", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Output Browser. + /// + public static string Label_OutputsPageTitle { + get { + return ResourceManager.GetString("Label_OutputsPageTitle", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Output Type. + /// + public static string Label_OutputType { + get { + return ResourceManager.GetString("Label_OutputType", resourceCulture); + } + } + /// /// Looks up a localized string similar to Package Environment. /// @@ -1247,6 +1409,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Python Packages. + /// + public static string Label_PythonPackages { + get { + return ResourceManager.GetString("Label_PythonPackages", resourceCulture); + } + } + /// /// Looks up a localized string similar to Python Version Info. /// @@ -1481,6 +1652,15 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Text to Image. + /// + public static string Label_TextToImage { + get { + return ResourceManager.GetString("Label_TextToImage", resourceCulture); + } + } + /// /// Looks up a localized string similar to Theme. /// @@ -1526,6 +1706,24 @@ namespace StabilityMatrix.Avalonia.Languages { } } + /// + /// Looks up a localized string similar to Upscale. + /// + public static string Label_Upscale { + get { + return ResourceManager.GetString("Label_Upscale", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Output Sharing. + /// + public static string Label_UseSharedOutputFolder { + get { + return ResourceManager.GetString("Label_UseSharedOutputFolder", resourceCulture); + } + } + /// /// Looks up a localized string similar to VAE. /// diff --git a/StabilityMatrix.Avalonia/Languages/Resources.resx b/StabilityMatrix.Avalonia/Languages/Resources.resx index 7575d1d8..f16aba39 100644 --- a/StabilityMatrix.Avalonia/Languages/Resources.resx +++ b/StabilityMatrix.Avalonia/Languages/Resources.resx @@ -379,7 +379,7 @@ Drop file here to import - Import as Connected + Import with Metadata Search for connected metadata on new local imports @@ -678,7 +678,73 @@ Restore Default Layout + + Output Sharing + Batch Index - \ No newline at end of file + + Copy + + + Open in Image Viewer + + + {0} images selected + + + Output Folder + + + Output Type + + + Clear Selection + + + Select All + + + Send to Inference + + + Text to Image + + + Image to Image + + + Inpainting + + + Upscale + + + Output Browser + + + 1 image selected + + + Python Packages + + + Consolidate + + + Are you sure? + + + This will move all generated images from the selected packages to the Consolidated directory of the shared outputs folder. This action cannot be undone. + + + Refresh + + + Upgrade + + + Downgrade + + diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs new file mode 100644 index 00000000..5017ab38 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormat +{ + public string Template { get; } + + public string Prefix { get; set; } = ""; + + public string Postfix { get; set; } = ""; + + public IReadOnlyList Parts { get; } + + private FileNameFormat(string template, IReadOnlyList parts) + { + Template = template; + Parts = parts; + } + + public FileNameFormat WithBatchPostFix(int current, int total) + { + return this with { Postfix = Postfix + $" ({current}-{total})" }; + } + + public FileNameFormat WithGridPrefix() + { + return this with { Prefix = Prefix + "Grid_" }; + } + + public string GetFileName() + { + return Prefix + + string.Join( + "", + Parts.Select( + part => part.Match(constant => constant, substitution => substitution.Invoke()) + ) + ) + + Postfix; + } + + public static FileNameFormat Parse(string template, FileNameFormatProvider provider) + { + var parts = provider.GetParts(template).ToImmutableArray(); + return new FileNameFormat(template, parts); + } + + public static bool TryParse( + string template, + FileNameFormatProvider provider, + [NotNullWhen(true)] out FileNameFormat? format + ) + { + try + { + format = Parse(template, provider); + return true; + } + catch (ArgumentException) + { + format = null; + return false; + } + } + + public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}"; +} diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs new file mode 100644 index 00000000..3b17284b --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs @@ -0,0 +1,7 @@ +using System; +using OneOf; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +[GenerateOneOf] +public partial class FileNameFormatPart : OneOfBase> { } diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs new file mode 100644 index 00000000..ff6905fd --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -0,0 +1,197 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.Contracts; +using System.IO; +using System.Linq; +using System.Text.RegularExpressions; +using Avalonia.Data; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Avalonia.Models.Inference; + +public partial class FileNameFormatProvider +{ + public GenerationParameters? GenerationParameters { get; init; } + + public InferenceProjectType? ProjectType { get; init; } + + public string? ProjectName { get; init; } + + private Dictionary>? _substitutions; + + public Dictionary> Substitutions => + _substitutions ??= new Dictionary> + { + { "seed", () => GenerationParameters?.Seed.ToString() }, + { "prompt", () => GenerationParameters?.PositivePrompt }, + { "negative_prompt", () => GenerationParameters?.NegativePrompt }, + { + "model_name", + () => Path.GetFileNameWithoutExtension(GenerationParameters?.ModelName) + }, + { "model_hash", () => GenerationParameters?.ModelHash }, + { "width", () => GenerationParameters?.Width.ToString() }, + { "height", () => GenerationParameters?.Height.ToString() }, + { "project_type", () => ProjectType?.GetStringValue() }, + { "project_name", () => ProjectName }, + { "date", () => DateTime.Now.ToString("yyyy-MM-dd") }, + { "time", () => DateTime.Now.ToString("HH-mm-ss") } + }; + + /// + /// Validate a format string + /// + /// Format string + /// Thrown if the format string contains an unknown variable + [Pure] + public ValidationResult Validate(string format) + { + var regex = BracketRegex(); + var matches = regex.Matches(format); + var variables = matches.Select(m => m.Groups[1].Value); + + foreach (var variableText in variables) + { + try + { + var (variable, _) = ExtractVariableAndSlice(variableText); + + if (!Substitutions.ContainsKey(variable)) + { + return new ValidationResult($"Unknown variable '{variable}'"); + } + } + catch (Exception e) + { + return new ValidationResult($"Invalid variable '{variableText}': {e.Message}"); + } + } + + return ValidationResult.Success!; + } + + public IEnumerable GetParts(string template) + { + var regex = BracketRegex(); + var matches = regex.Matches(template); + + var parts = new List(); + + // Loop through all parts of the string, including matches and non-matches + var currentIndex = 0; + + foreach (var result in matches.Cast()) + { + // If the match is not at the start of the string, add a constant part + if (result.Index != currentIndex) + { + var constant = template[currentIndex..result.Index]; + parts.Add(constant); + + currentIndex += constant.Length; + } + + // Now we're at start of the current match, add the variable part + var (variable, slice) = ExtractVariableAndSlice(result.Groups[1].Value); + var substitution = Substitutions[variable]; + + // Slice string if necessary + if (slice is not null) + { + parts.Add( + (FileNameFormatPart)( + () => + { + var value = substitution(); + if (value is null) + return null; + + if (slice.End is null) + { + value = value[(slice.Start ?? 0)..]; + } + else + { + var length = + Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0); + value = value.Substring(slice.Start ?? 0, length); + } + + return value; + } + ) + ); + } + else + { + parts.Add(substitution); + } + + currentIndex += result.Length; + } + + // Add remaining as constant + if (currentIndex != template.Length) + { + var constant = template[currentIndex..]; + parts.Add(constant); + } + + return parts; + } + + /// + /// Return a sample provider for UI preview + /// + public static FileNameFormatProvider GetSample() + { + return new FileNameFormatProvider + { + GenerationParameters = GenerationParameters.GetSample(), + ProjectType = InferenceProjectType.TextToImage, + ProjectName = "Sample Project" + }; + } + + /// + /// Extract variable and index from a combined string + /// + private static (string Variable, Slice? Slice) ExtractVariableAndSlice(string combined) + { + if (IndexRegex().Matches(combined).FirstOrDefault() is not { Success: true } match) + { + return (combined, null); + } + + // Variable is everything before the match + var variable = combined[..match.Groups[0].Index]; + + var start = match.Groups["start"].Value; + var end = match.Groups["end"].Value; + var step = match.Groups["step"].Value; + + var slice = new Slice( + string.IsNullOrEmpty(start) ? null : int.Parse(start), + string.IsNullOrEmpty(end) ? null : int.Parse(end), + string.IsNullOrEmpty(step) ? null : int.Parse(step) + ); + + return (variable, slice); + } + + /// + /// Regex for matching contents within a curly brace. + /// + [GeneratedRegex(@"\{([a-z_:\d\[\]]+)\}")] + private static partial Regex BracketRegex(); + + /// + /// Regex for matching a Python-like array index. + /// + [GeneratedRegex(@"\[(?:(?-?\d+)?)\:(?:(?-?\d+)?)?(?:\:(?-?\d+))?\]")] + private static partial Regex IndexRegex(); + + private record Slice(int? Start, int? End, int? Step); +} diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs new file mode 100644 index 00000000..a453b3bc --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs @@ -0,0 +1,8 @@ +namespace StabilityMatrix.Avalonia.Models.Inference; + +public record FileNameFormatVar +{ + public required string Variable { get; init; } + + public string? Example { get; init; } +} diff --git a/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs b/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs index c4c76fa4..4c18bd81 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.ComponentModel.DataAnnotations; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; diff --git a/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs b/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs index 5c782de4..e361207b 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs @@ -1,5 +1,4 @@ using System.Text.Json.Serialization; -using StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.Models.Inference; diff --git a/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs b/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs index f736fdd8..689fe144 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs @@ -1,5 +1,4 @@ -using System.Text.Json.Nodes; -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; namespace StabilityMatrix.Avalonia.Models.Inference; diff --git a/StabilityMatrix.Avalonia/Models/PackageOutputCategory.cs b/StabilityMatrix.Avalonia/Models/PackageOutputCategory.cs new file mode 100644 index 00000000..2318d0f8 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/PackageOutputCategory.cs @@ -0,0 +1,7 @@ +namespace StabilityMatrix.Avalonia.Models; + +public class PackageOutputCategory +{ + public required string Name { get; set; } + public required string Path { get; set; } +} diff --git a/StabilityMatrix.Avalonia/Models/SharedState.cs b/StabilityMatrix.Avalonia/Models/SharedState.cs index d2dd8fe1..17fd3288 100644 --- a/StabilityMatrix.Avalonia/Models/SharedState.cs +++ b/StabilityMatrix.Avalonia/Models/SharedState.cs @@ -1,14 +1,17 @@ using CommunityToolkit.Mvvm.ComponentModel; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Models; /// /// Singleton DI service for observable shared UI state. /// +[Singleton] public partial class SharedState : ObservableObject { /// /// Whether debug mode enabled from settings page version tap. /// - [ObservableProperty] private bool isDebugMode; + [ObservableProperty] + private bool isDebugMode; } diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs index 1b4395a4..5b03207d 100644 --- a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Collections.Immutable; using System.Diagnostics; using System.Linq; using System.Text.RegularExpressions; @@ -16,6 +15,7 @@ using NLog; using StabilityMatrix.Avalonia.Controls.CodeCompletion; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; @@ -26,6 +26,7 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.Models.TagCompletion; +[Singleton(typeof(ICompletionProvider))] public partial class CompletionProvider : ICompletionProvider { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs index c93df5fa..fea2490c 100644 --- a/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; -using AvaloniaEdit.Document; -using StabilityMatrix.Core.Models.Tokens; +using StabilityMatrix.Core.Models.Tokens; namespace StabilityMatrix.Avalonia.Models.TagCompletion; diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs index 74b22d44..2181e5df 100644 --- a/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs +++ b/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs @@ -1,10 +1,12 @@ using System.Diagnostics.CodeAnalysis; using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Core.Attributes; using TextMateSharp.Grammars; using TextMateSharp.Registry; namespace StabilityMatrix.Avalonia.Models.TagCompletion; +[Singleton(typeof(ITokenizerProvider))] public class TokenizerProvider : ITokenizerProvider { private readonly Registry registry = new(new RegistryOptions(ThemeName.DarkPlus)); diff --git a/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs b/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs index 3caf0b86..50416643 100644 --- a/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs +++ b/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs @@ -3,5 +3,6 @@ public enum TextEditorPreset { None, - Prompt + Prompt, + Console } diff --git a/StabilityMatrix.Avalonia/Program.cs b/StabilityMatrix.Avalonia/Program.cs index 40c95d08..899f9fc7 100644 --- a/StabilityMatrix.Avalonia/Program.cs +++ b/StabilityMatrix.Avalonia/Program.cs @@ -7,6 +7,7 @@ using System.Reflection; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; +using AsyncAwaitBestPractices; using AsyncImageLoader; using Avalonia; using Avalonia.Controls; @@ -232,10 +233,16 @@ public class Program if (e.ExceptionObject is not Exception ex) return; - Logger.Fatal(ex, "Unhandled {Type}: {Message}", ex.GetType().Name, ex.Message); + // Exception automatically logged by Sentry if enabled if (SentrySdk.IsEnabled) { + ex.SetSentryMechanism("AppDomain.UnhandledException", handled: false); SentrySdk.CaptureException(ex); + SentrySdk.FlushAsync().SafeFireAndForget(); + } + else + { + Logger.Fatal(ex, "Unhandled {Type}: {Message}", ex.GetType().Name, ex.Message); } if ( @@ -290,6 +297,10 @@ public class Program [DoesNotReturn] private static void ExitWithException(Exception exception) { + if (SentrySdk.IsEnabled) + { + SentrySdk.Flush(); + } App.Shutdown(1); Dispatcher.UIThread.InvokeShutdown(); Environment.Exit(Marshal.GetHRForException(exception)); diff --git a/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs index bc6afaff..f40990fa 100644 --- a/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Threading; diff --git a/StabilityMatrix.Avalonia/Services/INotificationService.cs b/StabilityMatrix.Avalonia/Services/INotificationService.cs index 77a4d9d2..d3cc7e77 100644 --- a/StabilityMatrix.Avalonia/Services/INotificationService.cs +++ b/StabilityMatrix.Avalonia/Services/INotificationService.cs @@ -2,6 +2,8 @@ using System.Threading.Tasks; using Avalonia; using Avalonia.Controls.Notifications; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.Services; @@ -11,7 +13,8 @@ public interface INotificationService public void Initialize( Visual? visual, NotificationPosition position = NotificationPosition.BottomRight, - int maxItems = 3); + int maxItems = 3 + ); public void Show(INotification notification); @@ -26,7 +29,8 @@ public interface INotificationService Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error); + NotificationType appearance = NotificationType.Error + ); /// /// Attempt to run the given void task, showing a generic error notification if it fails. @@ -40,16 +44,18 @@ public interface INotificationService Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error); + NotificationType appearance = NotificationType.Error + ); /// /// Show a notification with the given parameters. /// void Show( - string title, + string title, string message, NotificationType appearance = NotificationType.Information, - TimeSpan? expiration = null); + TimeSpan? expiration = null + ); /// /// Show a notification that will not auto-dismiss. @@ -60,5 +66,15 @@ public interface INotificationService void ShowPersistent( string title, string message, - NotificationType appearance = NotificationType.Information); + NotificationType appearance = NotificationType.Information + ); + + /// + /// Show a notification for a that will not auto-dismiss. + /// + void ShowPersistent( + AppException exception, + NotificationType appearance = NotificationType.Error, + LogLevel logLevel = LogLevel.Warning + ); } diff --git a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs index 3e71a877..8254e467 100644 --- a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs +++ b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs @@ -13,6 +13,7 @@ using SkiaSharp; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Inference; using StabilityMatrix.Core.Models; @@ -27,6 +28,7 @@ namespace StabilityMatrix.Avalonia.Services; /// Manager for the current inference client /// Has observable shared properties for shared info like model names /// +[Singleton(typeof(IInferenceClientManager))] public partial class InferenceClientManager : ObservableObject, IInferenceClientManager { private readonly ILogger logger; @@ -345,6 +347,44 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken); } + private async Task MigrateLinksIfNeeded(PackagePair packagePair) + { + if (packagePair.InstalledPackage.FullPath is not { } packagePath) + { + throw new ArgumentException("Package path is null", nameof(packagePair)); + } + + var inferenceDir = settingsManager.ImagesInferenceDirectory; + inferenceDir.Create(); + + // For locally installed packages only + // Delete ./output/Inference + + var legacyInferenceLinkDir = new DirectoryPath( + packagePair.InstalledPackage.FullPath + ).JoinDir("output", "Inference"); + + if (legacyInferenceLinkDir.Exists) + { + logger.LogInformation( + "Deleting legacy inference link at {LegacyDir}", + legacyInferenceLinkDir + ); + + if (legacyInferenceLinkDir.IsSymbolicLink) + { + await legacyInferenceLinkDir.DeleteAsync(false); + } + else + { + logger.LogWarning( + "Legacy inference link at {LegacyDir} is not a symbolic link, skipping", + legacyInferenceLinkDir + ); + } + } + } + /// public async Task ConnectAsync( PackagePair packagePair, @@ -367,11 +407,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient logger.LogError(ex, "Error setting up completion provider"); }); - // Setup image folder links - await comfyPackage.SetupInferenceOutputFolderLinks( - packagePair.InstalledPackage.FullPath - ?? throw new InvalidOperationException("Package does not have a Path") - ); + await MigrateLinksIfNeeded(packagePair); // Get user defined host and port var host = packagePair.InstalledPackage.GetLaunchArgsHost(); diff --git a/StabilityMatrix.Avalonia/Services/NavigationService.cs b/StabilityMatrix.Avalonia/Services/NavigationService.cs index 4cc47881..b9fb1952 100644 --- a/StabilityMatrix.Avalonia/Services/NavigationService.cs +++ b/StabilityMatrix.Avalonia/Services/NavigationService.cs @@ -1,6 +1,4 @@ using System; -using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using FluentAvalonia.UI.Controls; using FluentAvalonia.UI.Media.Animation; @@ -8,10 +6,12 @@ using FluentAvalonia.UI.Navigation; using StabilityMatrix.Avalonia.Animations; using StabilityMatrix.Avalonia.ViewModels; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.Services; +[Singleton(typeof(INavigationService))] public class NavigationService : INavigationService { private Frame? _frame; diff --git a/StabilityMatrix.Avalonia/Services/NotificationService.cs b/StabilityMatrix.Avalonia/Services/NotificationService.cs index 1bb664b9..98611f8d 100644 --- a/StabilityMatrix.Avalonia/Services/NotificationService.cs +++ b/StabilityMatrix.Avalonia/Services/NotificationService.cs @@ -3,20 +3,32 @@ using System.Threading.Tasks; using Avalonia; using Avalonia.Controls; using Avalonia.Controls.Notifications; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.Services; +[Singleton(typeof(INotificationService))] public class NotificationService : INotificationService { + private readonly ILogger logger; private WindowNotificationManager? notificationManager; - + + public NotificationService(ILogger logger) + { + this.logger = logger; + } + public void Initialize( - Visual? visual, + Visual? visual, NotificationPosition position = NotificationPosition.BottomRight, - int maxItems = 4) + int maxItems = 4 + ) { - if (notificationManager is not null) return; + if (notificationManager is not null) + return; notificationManager = new WindowNotificationManager(TopLevel.GetTopLevel(visual)) { Position = position, @@ -30,28 +42,44 @@ public class NotificationService : INotificationService } public void Show( - string title, + string title, string message, NotificationType appearance = NotificationType.Information, - TimeSpan? expiration = null) + TimeSpan? expiration = null + ) { Show(new Notification(title, message, appearance, expiration)); } public void ShowPersistent( - string title, + string title, string message, - NotificationType appearance = NotificationType.Information) + NotificationType appearance = NotificationType.Information + ) { Show(new Notification(title, message, appearance, TimeSpan.Zero)); } - + + /// + public void ShowPersistent( + AppException exception, + NotificationType appearance = NotificationType.Warning, + LogLevel logLevel = LogLevel.Warning + ) + { + // Log exception + logger.Log(logLevel, exception, "{Message}", exception.Message); + + Show(new Notification(exception.Message, exception.Details, appearance, TimeSpan.Zero)); + } + /// public async Task> TryAsync( Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error) + NotificationType appearance = NotificationType.Error + ) { try { @@ -63,13 +91,14 @@ public class NotificationService : INotificationService return TaskResult.FromException(e); } } - + /// public async Task> TryAsync( Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error) + NotificationType appearance = NotificationType.Error + ) { try { diff --git a/StabilityMatrix.Avalonia/Services/ServiceManager.cs b/StabilityMatrix.Avalonia/Services/ServiceManager.cs index 26d604f4..4b21c294 100644 --- a/StabilityMatrix.Avalonia/Services/ServiceManager.cs +++ b/StabilityMatrix.Avalonia/Services/ServiceManager.cs @@ -256,4 +256,19 @@ public class ServiceManager return new BetterContentDialog { Content = view }; } + + public void Register(Type type, Func providerFunc) + { + lock (providers) + { + if (instances.ContainsKey(type) || providers.ContainsKey(type)) + { + throw new ArgumentException( + $"Service of type {type} is already registered for {typeof(T)}" + ); + } + + providers[type] = providerFunc; + } + } } diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj index 648330dd..4ceea5b9 100644 --- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj +++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj @@ -8,7 +8,7 @@ app.manifest true ./Assets/Icon.ico - 2.5.5-dev.1 + 2.6.0-dev.1 $(Version) true true @@ -23,46 +23,48 @@ - - - + + + - + - + - + - - + + - + - + - - - + + + + - + - - + + + - + diff --git a/StabilityMatrix.Avalonia/Styles/ThemeColors.cs b/StabilityMatrix.Avalonia/Styles/ThemeColors.cs index 5df260c2..c59b9d53 100644 --- a/StabilityMatrix.Avalonia/Styles/ThemeColors.cs +++ b/StabilityMatrix.Avalonia/Styles/ThemeColors.cs @@ -1,5 +1,4 @@ -using Avalonia; -using Avalonia.Media; +using Avalonia.Media; namespace StabilityMatrix.Avalonia.Styles; diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs index b905bba6..a4ed0309 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs @@ -1,10 +1,14 @@ using System; +using CommunityToolkit.Mvvm.ComponentModel; using FluentAvalonia.UI.Controls; namespace StabilityMatrix.Avalonia.ViewModels.Base; -public class ContentDialogProgressViewModelBase : ConsoleProgressViewModel +public partial class ContentDialogProgressViewModelBase : ConsoleProgressViewModel { + [ObservableProperty] + private bool hideCloseButton; + public event EventHandler? PrimaryButtonClick; public event EventHandler? SecondaryButtonClick; public event EventHandler? CloseButtonClick; diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index 3bd7e614..f1d0d11c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -3,11 +3,13 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; +using Avalonia.Controls.Notifications; using Avalonia.Threading; using CommunityToolkit.Mvvm.Input; using NLog; @@ -27,6 +29,8 @@ using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Base; @@ -41,6 +45,7 @@ public abstract partial class InferenceGenerationViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + private readonly ISettingsManager settingsManager; private readonly INotificationService notificationService; private readonly ServiceManager vmFactory; @@ -60,11 +65,13 @@ public abstract partial class InferenceGenerationViewModelBase protected InferenceGenerationViewModelBase( ServiceManager vmFactory, IInferenceClientManager inferenceClientManager, - INotificationService notificationService + INotificationService notificationService, + ISettingsManager settingsManager ) : base(notificationService) { this.notificationService = notificationService; + this.settingsManager = settingsManager; this.vmFactory = vmFactory; ClientManager = inferenceClientManager; @@ -75,6 +82,101 @@ public abstract partial class InferenceGenerationViewModelBase GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService); } + /// + /// Write an image to the default output folder + /// + protected Task WriteOutputImageAsync( + Stream imageStream, + ImageGenerationEventArgs args, + int batchNum = 0, + int batchTotal = 0, + bool isGrid = false + ) + { + var defaultOutputDir = settingsManager.ImagesInferenceDirectory; + defaultOutputDir.Create(); + + return WriteOutputImageAsync( + imageStream, + defaultOutputDir, + args, + batchNum, + batchTotal, + isGrid + ); + } + + /// + /// Write an image to an output folder + /// + protected async Task WriteOutputImageAsync( + Stream imageStream, + DirectoryPath outputDir, + ImageGenerationEventArgs args, + int batchNum = 0, + int batchTotal = 0, + bool isGrid = false + ) + { + var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat; + + var formatProvider = new FileNameFormatProvider + { + GenerationParameters = args.Parameters, + ProjectType = args.Project?.ProjectType, + ProjectName = ProjectFile?.NameWithoutExtension + }; + + // Parse to format + if ( + string.IsNullOrEmpty(formatTemplateStr) + || !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format) + ) + { + // Fallback to default + Logger.Warn( + "Failed to parse format template: {FormatTemplate}, using default", + formatTemplateStr + ); + + format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider); + } + + if (isGrid) + { + format = format.WithGridPrefix(); + } + + if (batchNum >= 1 && batchTotal > 1) + { + format = format.WithBatchPostFix(batchNum, batchTotal); + } + + var fileName = format.GetFileName(); + var file = outputDir.JoinFile($"{fileName}.png"); + + // Until the file is free, keep adding _{i} to the end + for (var i = 0; i < 100; i++) + { + if (!file.Exists) + break; + + file = outputDir.JoinFile($"{fileName}_{i + 1}.png"); + } + + // If that fails, append an 7-char uuid + if (file.Exists) + { + var uuid = Guid.NewGuid().ToString("N")[..7]; + file = outputDir.JoinFile($"{fileName}_{uuid}.png"); + } + + await using var fileStream = file.Info.OpenWrite(); + await imageStream.CopyToAsync(fileStream); + + return file; + } + /// /// Builds the image generation prompt /// @@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase // Wait for prompt to finish await promptTask.Task.WaitAsync(cancellationToken); - Logger.Trace($"Prompt task {promptTask.Id} finished"); + Logger.Debug($"Prompt task {promptTask.Id} finished"); // Get output images var imageOutputs = await client.GetImagesForExecutedPromptAsync( @@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase cancellationToken ); + if ( + !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) + || images is not { Count: > 0 } + ) + { + // No images match + notificationService.Show( + "No output", + "Did not receive any output images", + NotificationType.Warning + ); + return; + } + // Disable cancellation await promptInterrupt.DisposeAsync(); @@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase ImageGalleryCardViewModel.ImageSources.Clear(); } - if ( - !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null - ) - { - // No images match - notificationService.Show("No output", "Did not receive any output images"); - return; - } - await ProcessOutputImages(images, args); } finally @@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase ImageGenerationEventArgs args ) { + var client = args.Client; + // Write metadata to images + var outputImagesBytes = new List(); var outputImages = new List(); - foreach ( - var (i, filePath) in images - .Select(image => image.ToFilePath(args.Client.OutputImagesDir!)) - .Enumerate() - ) + + foreach (var (i, comfyImage) in images.Enumerate()) { - if (!filePath.Exists) - { - Logger.Warn($"Image file {filePath} does not exist"); - continue; - } + Logger.Debug("Downloading image: {FileName}", comfyImage.FileName); + var imageStream = await client.GetImageStreamAsync(comfyImage); + + using var ms = new MemoryStream(); + await imageStream.CopyToAsync(ms); + + var imageArray = ms.ToArray(); + outputImagesBytes.Add(imageArray); var parameters = args.Parameters!; var project = args.Project!; @@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase ); } - var bytesWithMetadata = PngDataHelper.AddMetadata( - await filePath.ReadAllBytesAsync(), - parameters, - project - ); + var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project); - await using (var outputStream = filePath.Info.OpenWrite()) - { - await outputStream.WriteAsync(bytesWithMetadata); - await outputStream.FlushAsync(); - } + // Write using generated name + var filePath = await WriteOutputImageAsync( + new MemoryStream(bytesWithMetadata), + args, + i + 1, + images.Count + ); outputImages.Add(new ImageSource(filePath)); @@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase // Download all images to make grid, if multiple if (outputImages.Count > 1) { - var outputDir = outputImages[0].LocalFile!.Directory; - - var loadedImages = outputImages - .Select(i => i.LocalFile) - .Where(f => f is { Exists: true }) - .Select(f => - { - using var stream = f!.Info.OpenRead(); - return SKImage.FromEncodedData(stream); - }) - .ToImmutableArray(); + var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray(); var project = args.Project!; @@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase ); // Save to disk - var lastName = outputImages.Last().LocalFile?.Info.Name; - var gridPath = outputDir!.JoinFile($"grid-{lastName}"); - - await using (var fileStream = gridPath.Info.OpenWrite()) - { - await fileStream.WriteAsync(gridBytesWithMetadata); - } + var gridPath = await WriteOutputImageAsync( + new MemoryStream(gridBytesWithMetadata), + args, + isGrid: true + ); // Insert to start of images var gridImage = new ImageSource(gridPath); diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs index b31b5501..bc1d60ff 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs @@ -310,7 +310,7 @@ public abstract partial class InferenceTabViewModelBase if (this is IImageGalleryComponent imageGalleryComponent) { imageGalleryComponent.LoadImagesToGallery( - new ImageSource(imageFile.GlobalFullPath) + new ImageSource(imageFile.AbsolutePath) ); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs index e4e458d8..0d06e87f 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs @@ -1,5 +1,4 @@ using System; -using System.Threading.Tasks; using CommunityToolkit.Mvvm.ComponentModel; namespace StabilityMatrix.Avalonia.ViewModels.Base; diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs index a4785ae6..43288d83 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs @@ -17,6 +17,7 @@ using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api; @@ -27,6 +28,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; +[ManagedService] +[Transient] public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -254,16 +257,17 @@ public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel private static string PruneDescription(CivitModel model) { - var prunedDescription = model.Description - .Replace("
", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("
", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("

", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("", $"{Environment.NewLine}{Environment.NewLine}") - .Replace("", $"{Environment.NewLine}{Environment.NewLine}"); + var prunedDescription = + model.Description + ?.Replace("
", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("
", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("

", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("", $"{Environment.NewLine}{Environment.NewLine}") + .Replace("", $"{Environment.NewLine}{Environment.NewLine}") ?? string.Empty; prunedDescription = HtmlRegex().Replace(prunedDescription, string.Empty); return prunedDescription; } diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs index 2ecfc90e..e7659258 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs @@ -5,16 +5,13 @@ using System.ComponentModel; using System.Diagnostics; using System.Linq; using System.Net.Http; -using System.Reactive; using System.Reactive.Linq; -using System.Reactive.Threading.Tasks; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Collections; using Avalonia.Controls; using Avalonia.Controls.Notifications; -using AvaloniaEdit.Utils; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; @@ -42,6 +39,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(CheckpointBrowserPage))] +[Singleton] public partial class CheckpointBrowserViewModel : PageViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs index f1260f93..a978d874 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs @@ -9,7 +9,9 @@ using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; using NLog; +using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; @@ -19,6 +21,8 @@ using StabilityMatrix.Core.Processes; namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager; +[ManagedService] +[Transient] public partial class CheckpointFile : ViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -257,6 +261,11 @@ public partial class CheckpointFile : ViewModelBase .Where(File.Exists) .FirstOrDefault(); + if (string.IsNullOrWhiteSpace(checkpointFile.PreviewImagePath)) + { + checkpointFile.PreviewImagePath = Assets.NoImage.ToString(); + } + yield return checkpointFile; } } diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs index 398aaf64..ecff1c02 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs @@ -17,6 +17,7 @@ using DynamicData.Binding; using FluentAvalonia.UI.Controls; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; @@ -27,6 +28,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager; +[ManagedService] +[Transient] public partial class CheckpointFolder : ViewModelBase { private readonly ISettingsManager settingsManager; @@ -99,7 +102,7 @@ public partial class CheckpointFolder : ViewModelBase public IObservableCollection CheckpointFiles { get; } = new ObservableCollectionExtended(); - public IObservableCollection DisplayedCheckpointFiles { get; } = + public IObservableCollection DisplayedCheckpointFiles { get; set; } = new ObservableCollectionExtended(); public CheckpointFolder( diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs index c77b92d2..86b7fdd3 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs @@ -24,6 +24,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(CheckpointsPage))] +[Singleton] public partial class CheckpointsPageViewModel : PageViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs index aaf2aaa0..47625eaf 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs @@ -16,6 +16,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(DownloadResourceDialog))] +[ManagedService] +[Transient] public partial class DownloadResourceViewModel : ContentDialogViewModelBase { private readonly IDownloadService downloadService; diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs index 55fa73b5..cab1a11a 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs @@ -12,6 +12,8 @@ using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(EnvVarsViewModel))] +[ManagedService] +[Transient] public partial class EnvVarsViewModel : ContentDialogViewModelBase { [ObservableProperty] diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs index 5c0b9427..42e8a6d6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs @@ -6,11 +6,13 @@ using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(ExceptionDialog))] +[ManagedService] +[Transient] public partial class ExceptionViewModel : ViewModelBase { public Exception? Exception { get; set; } - + public string? Message => Exception?.Message; - + public string? ExceptionType => Exception?.GetType().Name ?? ""; } diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs index 821d8f5b..c07824e7 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs @@ -21,6 +21,8 @@ using Size = StabilityMatrix.Core.Helper.Size; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(ImageViewerDialog))] +[ManagedService] +[Transient] public partial class ImageViewerViewModel : ContentDialogViewModelBase { [ObservableProperty] diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs index ac0d1dc1..351f9a64 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs @@ -6,7 +6,6 @@ using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; -using StabilityMatrix.Avalonia.Animations; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Services; @@ -23,6 +22,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(InferenceConnectionHelpDialog))] +[ManagedService] +[Transient] public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBase { private readonly ISettingsManager settingsManager; diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs index 33f92b9d..d156def6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs @@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Factory; using StabilityMatrix.Core.Models; @@ -31,11 +32,14 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; +[ManagedService] +[Transient] public partial class InstallerViewModel : ContentDialogViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private readonly ISettingsManager settingsManager; + private readonly IPackageFactory packageFactory; private readonly IPyRunner pyRunner; private readonly IDownloadService downloadService; private readonly INotificationService notificationService; @@ -47,15 +51,15 @@ public partial class InstallerViewModel : ContentDialogViewModelBase [ObservableProperty] private PackageVersion? selectedVersion; - [ObservableProperty] - private IReadOnlyList? availablePackages; - [ObservableProperty] private ObservableCollection? availableCommits; [ObservableProperty] private ObservableCollection? availableVersions; + [ObservableProperty] + private ObservableCollection availablePackages; + [ObservableProperty] private GitCommit? selectedCommit; @@ -69,6 +73,9 @@ public partial class InstallerViewModel : ContentDialogViewModelBase [NotifyPropertyChangedFor(nameof(CanInstall))] private bool showDuplicateWarning; + [ObservableProperty] + private bool showIncompatiblePackages; + [ObservableProperty] [NotifyPropertyChangedFor(nameof(CanInstall))] private string? installName; @@ -115,7 +122,6 @@ public partial class InstallerViewModel : ContentDialogViewModelBase public bool CanInstall => !string.IsNullOrWhiteSpace(InstallName) && !ShowDuplicateWarning && !IsLoading; - public ProgressViewModel InstallProgress { get; } = new(); public IEnumerable Steps { get; set; } public InstallerViewModel( @@ -128,22 +134,25 @@ public partial class InstallerViewModel : ContentDialogViewModelBase ) { this.settingsManager = settingsManager; + this.packageFactory = packageFactory; this.pyRunner = pyRunner; this.downloadService = downloadService; this.notificationService = notificationService; this.prerequisiteHelper = prerequisiteHelper; - // AvailablePackages and SelectedPackage + var filtered = packageFactory.GetAllAvailablePackages().Where(p => p.IsCompatible).ToList(); + AvailablePackages = new ObservableCollection( - packageFactory.GetAllAvailablePackages() + filtered.Any() ? filtered : packageFactory.GetAllAvailablePackages() ); - SelectedPackage = AvailablePackages[0]; + ShowIncompatiblePackages = !filtered.Any(); } public override void OnLoaded() { if (AvailablePackages == null) return; + IsReleaseMode = !SelectedPackage.ShouldIgnoreReleases; } @@ -238,6 +247,8 @@ public partial class InstallerViewModel : ContentDialogViewModelBase downloadOptions.VersionTag = SelectedVersion?.TagName ?? throw new NullReferenceException("Selected version is null"); + downloadOptions.IsLatest = + AvailableVersions?.First().TagName == downloadOptions.VersionTag; installedVersion.InstalledReleaseVersion = downloadOptions.VersionTag; } @@ -245,6 +256,11 @@ public partial class InstallerViewModel : ContentDialogViewModelBase { downloadOptions.CommitHash = SelectedCommit?.Sha ?? throw new NullReferenceException("Selected commit is null"); + downloadOptions.BranchName = + SelectedVersion?.TagName + ?? throw new NullReferenceException("Selected version is null"); + downloadOptions.IsLatest = AvailableCommits?.First().Sha == SelectedCommit.Sha; + installedVersion.InstalledBranch = SelectedVersion?.TagName ?? throw new NullReferenceException("Selected version is null"); @@ -259,6 +275,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase var installStep = new InstallPackageStep( SelectedPackage, SelectedTorchVersion, + downloadOptions, installLocation ); var setupModelFoldersStep = new SetupModelFoldersStep( @@ -301,12 +318,29 @@ public partial class InstallerViewModel : ContentDialogViewModelBase OnCloseButtonClick(); } + partial void OnShowIncompatiblePackagesChanged(bool value) + { + var filtered = packageFactory + .GetAllAvailablePackages() + .Where(p => ShowIncompatiblePackages || p.IsCompatible) + .ToList(); + + AvailablePackages = new ObservableCollection( + filtered.Any() ? filtered : packageFactory.GetAllAvailablePackages() + ); + SelectedPackage = AvailablePackages[0]; + } + private void UpdateSelectedVersionToLatestMain() { if (AvailableVersions is null) { SelectedVersion = null; } + else if (SelectedPackage is FooocusMre) + { + SelectedVersion = AvailableVersions.FirstOrDefault(x => x.TagName == "moonride-main"); + } else { // First try to find master @@ -358,41 +392,33 @@ public partial class InstallerViewModel : ContentDialogViewModelBase // When changing branch / release modes, refresh // ReSharper disable once UnusedParameterInPartialMethod - partial void OnSelectedVersionTypeChanged(PackageVersionType value) => - OnSelectedPackageChanged(SelectedPackage); - - partial void OnSelectedPackageChanged(BasePackage value) + partial void OnSelectedVersionTypeChanged(PackageVersionType value) { - IsLoading = true; - ReleaseNotes = string.Empty; - AvailableVersions?.Clear(); - AvailableCommits?.Clear(); - - AvailableVersionTypes = SelectedPackage.ShouldIgnoreReleases - ? PackageVersionType.Commit - : PackageVersionType.GithubRelease | PackageVersionType.Commit; - SelectedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod; - SelectedTorchVersion = SelectedPackage.GetRecommendedTorchVersion(); - if (Design.IsDesignMode) + if (SelectedPackage is null || Design.IsDesignMode) return; Dispatcher.UIThread .InvokeAsync(async () => { Logger.Debug($"Release mode: {IsReleaseMode}"); - var versionOptions = await value.GetAllVersionOptions(); + var versionOptions = await SelectedPackage.GetAllVersionOptions(); AvailableVersions = IsReleaseMode ? new ObservableCollection(versionOptions.AvailableVersions) : new ObservableCollection(versionOptions.AvailableBranches); - SelectedVersion = AvailableVersions.First(x => !x.IsPrerelease); + SelectedVersion = AvailableVersions?.FirstOrDefault(x => !x.IsPrerelease); + if (SelectedVersion is null) + return; + ReleaseNotes = SelectedVersion.ReleaseNotesMarkdown; Logger.Debug($"Loaded release notes for {ReleaseNotes}"); if (!IsReleaseMode) { - var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList(); + var commits = ( + await SelectedPackage.GetAllCommits(SelectedVersion.TagName) + )?.ToList(); if (commits is null || commits.Count == 0) return; @@ -408,6 +434,29 @@ public partial class InstallerViewModel : ContentDialogViewModelBase .SafeFireAndForget(); } + partial void OnSelectedPackageChanged(BasePackage? value) + { + IsLoading = true; + ReleaseNotes = string.Empty; + AvailableVersions?.Clear(); + AvailableCommits?.Clear(); + + if (value == null) + return; + + AvailableVersionTypes = SelectedPackage.ShouldIgnoreReleases + ? PackageVersionType.Commit + : PackageVersionType.GithubRelease | PackageVersionType.Commit; + IsReleaseMode = !SelectedPackage.ShouldIgnoreReleases; + SelectedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod; + SelectedTorchVersion = SelectedPackage.GetRecommendedTorchVersion(); + SelectedVersionType = SelectedPackage.ShouldIgnoreReleases + ? PackageVersionType.Commit + : PackageVersionType.GithubRelease; + + OnSelectedVersionTypeChanged(SelectedVersionType); + } + partial void OnInstallNameChanged(string? value) { ShowDuplicateWarning = settingsManager.Settings.InstalledPackages.Any( @@ -418,7 +467,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase partial void OnSelectedVersionChanged(PackageVersion? value) { ReleaseNotes = value?.ReleaseNotesMarkdown ?? string.Empty; - if (value == null) + if (value == null || Design.IsDesignMode) return; SelectedCommit = null; diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs index a4d2e4a6..303d4462 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs @@ -17,6 +17,8 @@ using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(LaunchOptionsDialog))] +[ManagedService] +[Transient] public partial class LaunchOptionsViewModel : ContentDialogViewModelBase { private readonly ILogger logger; diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs index b4d4a823..155e2297 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.IO; using System.Linq; @@ -7,25 +8,29 @@ using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using Microsoft.Extensions.Logging; using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Factory; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.PackageModification; using StabilityMatrix.Core.Models.Packages; -using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; -public partial class OneClickInstallViewModel : ViewModelBase +[ManagedService] +[Transient] +public partial class OneClickInstallViewModel : ContentDialogViewModelBase { private readonly ISettingsManager settingsManager; private readonly IPackageFactory packageFactory; private readonly IPrerequisiteHelper prerequisiteHelper; private readonly ILogger logger; private readonly IPyRunner pyRunner; - private readonly ISharedFolders sharedFolders; + private readonly INavigationService navigationService; private const string DefaultPackageName = "stable-diffusion-webui"; [ObservableProperty] @@ -43,6 +48,9 @@ public partial class OneClickInstallViewModel : ViewModelBase [ObservableProperty] private bool isIndeterminate; + [ObservableProperty] + private bool showIncompatiblePackages; + [ObservableProperty] private ObservableCollection allPackages; @@ -53,6 +61,8 @@ public partial class OneClickInstallViewModel : ViewModelBase [NotifyPropertyChangedFor(nameof(IsProgressBarVisible))] private int oneClickInstallProgress; + private bool isInferenceInstall; + public bool IsProgressBarVisible => OneClickInstallProgress > 0 || IsIndeterminate; public OneClickInstallViewModel( @@ -61,7 +71,7 @@ public partial class OneClickInstallViewModel : ViewModelBase IPrerequisiteHelper prerequisiteHelper, ILogger logger, IPyRunner pyRunner, - ISharedFolders sharedFolders + INavigationService navigationService ) { this.settingsManager = settingsManager; @@ -69,13 +79,21 @@ public partial class OneClickInstallViewModel : ViewModelBase this.prerequisiteHelper = prerequisiteHelper; this.logger = logger; this.pyRunner = pyRunner; - this.sharedFolders = sharedFolders; + this.navigationService = navigationService; HeaderText = Resources.Text_WelcomeToStabilityMatrix; SubHeaderText = Resources.Text_OneClickInstaller_SubHeader; ShowInstallButton = true; + + var filteredPackages = this.packageFactory + .GetAllAvailablePackages() + .Where(p => p is { OfferInOneClickInstaller: true, IsCompatible: true }) + .ToList(); + AllPackages = new ObservableCollection( - this.packageFactory.GetAllAvailablePackages().Where(p => p.OfferInOneClickInstaller) + filteredPackages.Any() + ? filteredPackages + : this.packageFactory.GetAllAvailablePackages() ); SelectedPackage = AllPackages[0]; } @@ -95,39 +113,34 @@ public partial class OneClickInstallViewModel : ViewModelBase return Task.CompletedTask; } - private async Task DoInstall() + [RelayCommand] + private async Task InstallComfyForInference() { - HeaderText = $"{Resources.Label_Installing} {SelectedPackage.DisplayName}"; - - var progressHandler = new Progress(progress => - { - SubHeaderText = $"{progress.Title} {progress.Percentage:N0}%"; - - IsIndeterminate = progress.IsIndeterminate; - OneClickInstallProgress = Convert.ToInt32(progress.Percentage); - }); - - await prerequisiteHelper.InstallAllIfNecessary(progressHandler); - - SubHeaderText = Resources.Progress_InstallingPrerequisites; - IsIndeterminate = true; - if (!PyRunner.PipInstalled) + var comfyPackage = AllPackages.FirstOrDefault(x => x is ComfyUI); + if (comfyPackage != null) { - await pyRunner.SetupPip(); + SelectedPackage = comfyPackage; + isInferenceInstall = true; + await InstallCommand.ExecuteAsync(null); } + } - if (!PyRunner.VenvInstalled) + private async Task DoInstall() + { + var steps = new List { - await pyRunner.InstallPackage("virtualenv"); - } - IsIndeterminate = false; - - var libraryDir = settingsManager.LibraryDir; + new SetPackageInstallingStep(settingsManager, SelectedPackage.Name), + new SetupPrerequisitesStep(prerequisiteHelper, pyRunner) + }; // get latest version & download & install - var installLocation = Path.Combine(libraryDir, "Packages", SelectedPackage.Name); + var installLocation = Path.Combine( + settingsManager.LibraryDir, + "Packages", + SelectedPackage.Name + ); - var downloadVersion = new DownloadPackageVersionOptions(); + var downloadVersion = new DownloadPackageVersionOptions { IsLatest = true }; var installedVersion = new InstalledPackageVersion(); var versionOptions = await SelectedPackage.GetAllVersionOptions(); @@ -139,16 +152,39 @@ public partial class OneClickInstallViewModel : ViewModelBase else { downloadVersion.BranchName = await SelectedPackage.GetLatestVersion(); + downloadVersion.CommitHash = + (await SelectedPackage.GetAllCommits(downloadVersion.BranchName)) + ?.FirstOrDefault() + ?.Sha ?? string.Empty; + installedVersion.InstalledBranch = downloadVersion.BranchName; + installedVersion.InstalledCommitSha = downloadVersion.CommitHash; } var torchVersion = SelectedPackage.GetRecommendedTorchVersion(); + var recommendedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod; - await DownloadPackage(installLocation, downloadVersion); - await InstallPackage(installLocation, torchVersion); + var downloadStep = new DownloadPackageVersionStep( + SelectedPackage, + installLocation, + downloadVersion + ); + steps.Add(downloadStep); - var recommendedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod; - await SelectedPackage.SetupModelFolders(installLocation, recommendedSharedFolderMethod); + var installStep = new InstallPackageStep( + SelectedPackage, + torchVersion, + downloadVersion, + installLocation + ); + steps.Add(installStep); + + var setupModelFoldersStep = new SetupModelFoldersStep( + SelectedPackage, + recommendedSharedFolderMethod, + installLocation + ); + steps.Add(setupModelFoldersStep); var installedPackage = new InstalledPackage { @@ -162,59 +198,47 @@ public partial class OneClickInstallViewModel : ViewModelBase PreferredTorchVersion = torchVersion, PreferredSharedFolderMethod = recommendedSharedFolderMethod }; - await using var st = settingsManager.BeginTransaction(); - st.Settings.InstalledPackages.Add(installedPackage); - st.Settings.ActiveInstalledPackageId = installedPackage.Id; - EventManager.Instance.OnInstalledPackagesChanged(); - HeaderText = Resources.Progress_InstallationComplete; - SubSubHeaderText = string.Empty; - OneClickInstallProgress = 100; + var addInstalledPackageStep = new AddInstalledPackageStep( + settingsManager, + installedPackage + ); + steps.Add(addInstalledPackageStep); - for (var i = 0; i < 3; i++) + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + HideCloseButton = true, + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + await runner.ExecuteSteps(steps); + + EventManager.Instance.OnInstalledPackagesChanged(); + HeaderText = $"{SelectedPackage.DisplayName} installed successfully"; + for (var i = 3; i > 0; i--) { - SubHeaderText = $"{Resources.Text_ProceedingToLaunchPage} ({i + 1}s)"; + SubHeaderText = $"{Resources.Text_ProceedingToLaunchPage} ({i}s)"; await Task.Delay(1000); } // should close dialog EventManager.Instance.OnOneClickInstallFinished(false); - } - - private async Task DownloadPackage( - string installLocation, - DownloadPackageVersionOptions versionOptions - ) - { - SubHeaderText = Resources.Progress_DownloadingPackage; - - var progress = new Progress(progress => + if (isInferenceInstall) { - IsIndeterminate = progress.IsIndeterminate; - OneClickInstallProgress = Convert.ToInt32(progress.Percentage); - EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress); - }); - - await SelectedPackage.DownloadPackage(installLocation, versionOptions, progress); - SubHeaderText = Resources.Progress_DownloadComplete; - OneClickInstallProgress = 100; + navigationService.NavigateTo(); + } } - private async Task InstallPackage(string installLocation, TorchVersion torchVersion) + partial void OnShowIncompatiblePackagesChanged(bool value) { - var progress = new Progress(progress => - { - SubHeaderText = Resources.Progress_InstallingPackageRequirements; - IsIndeterminate = progress.IsIndeterminate; - OneClickInstallProgress = Convert.ToInt32(progress.Percentage); - EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress); - }); + var filteredPackages = packageFactory + .GetAllAvailablePackages() + .Where(p => p.OfferInOneClickInstaller && (ShowIncompatiblePackages || p.IsCompatible)) + .ToList(); - await SelectedPackage.InstallPackage( - installLocation, - torchVersion, - progress, - (output) => SubSubHeaderText = output.Text + AllPackages = new ObservableCollection( + filteredPackages.Any() ? filteredPackages : packageFactory.GetAllAvailablePackages() ); + SelectedPackage = AllPackages[0]; } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs index 915c2a2f..06ff94de 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs @@ -23,6 +23,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(PackageImportDialog))] +[ManagedService] +[Transient] public partial class PackageImportViewModel : ContentDialogViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs new file mode 100644 index 00000000..7259df54 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs @@ -0,0 +1,129 @@ +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Avalonia.Controls; +using CommunityToolkit.Mvvm.ComponentModel; +using Semver; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +public partial class PythonPackagesItemViewModel : ViewModelBase +{ + [ObservableProperty] + private PipPackageInfo package; + + [ObservableProperty] + private string? selectedVersion; + + [ObservableProperty] + private IReadOnlyList? availableVersions; + + [ObservableProperty] + private PipShowResult? pipShowResult; + + [ObservableProperty] + private bool isLoading; + + /// + /// True if selected version is newer than the installed version + /// + [ObservableProperty] + private bool canUpgrade; + + /// + /// True if selected version is older than the installed version + /// + [ObservableProperty] + private bool canDowngrade; + + partial void OnSelectedVersionChanged(string? value) + { + if ( + value is null + || Package.Version == value + || !SemVersion.TryParse(Package.Version, out var currentSemver) + || !SemVersion.TryParse(value, out var selectedSemver) + ) + { + CanUpgrade = false; + CanDowngrade = false; + return; + } + + var precedence = selectedSemver.ComparePrecedenceTo(currentSemver); + + CanUpgrade = precedence > 0; + CanDowngrade = precedence < 0; + } + + /// + /// Return the known index URL for a package, currently this is torch, torchvision and torchaudio + /// + public static string? GetKnownIndexUrl(string packageName, string version) + { + var torchPackages = new[] { "torch", "torchvision", "torchaudio" }; + if (torchPackages.Contains(packageName) && version.Contains('+')) + { + // Get the metadata for the current version (everything after the +) + var indexName = version.Split('+', 2).Last(); + + var indexUrl = $"https://download.pytorch.org/whl/{indexName}"; + return indexUrl; + } + + return null; + } + + /// + /// Loads the pip show result if not already loaded + /// + public async Task LoadExtraInfo(DirectoryPath venvPath) + { + if (PipShowResult is not null) + { + return; + } + + IsLoading = true; + + try + { + if (Design.IsDesignMode) + { + await LoadExtraInfoDesignMode(); + } + else + { + await using var venvRunner = new PyVenvRunner(venvPath); + + PipShowResult = await venvRunner.PipShow(Package.Name); + + // Attempt to get known index url + var indexUrl = GetKnownIndexUrl(Package.Name, Package.Version); + + if (await venvRunner.PipIndex(Package.Name, indexUrl) is { } pipIndexResult) + { + AvailableVersions = pipIndexResult.AvailableVersions; + SelectedVersion = Package.Version; + } + } + } + finally + { + IsLoading = false; + } + } + + private async Task LoadExtraInfoDesignMode() + { + await using var _ = new MinimumDelay(200, 300); + + PipShowResult = new PipShowResult { Name = Package.Name, Version = Package.Version }; + AvailableVersions = new[] { Package.Version, "1.2.0", "1.1.0", "1.0.0" }; + SelectedVersion = Package.Version; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesViewModel.cs new file mode 100644 index 00000000..b7d00f8e --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesViewModel.cs @@ -0,0 +1,313 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using System.Threading; +using System.Threading.Tasks; +using AsyncAwaitBestPractices; +using Avalonia.Controls; +using Avalonia.Controls.Primitives; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using DynamicData; +using DynamicData.Binding; +using FluentAvalonia.UI.Controls; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.PackageModification; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +[View(typeof(PythonPackagesDialog))] +[ManagedService] +[Transient] +public partial class PythonPackagesViewModel : ContentDialogViewModelBase +{ + public DirectoryPath? VenvPath { get; set; } + + [ObservableProperty] + private bool isLoading; + + private readonly SourceCache packageSource = new(p => p.Name); + + public IObservableCollection Packages { get; } = + new ObservableCollectionExtended(); + + [ObservableProperty] + private PythonPackagesItemViewModel? selectedPackage; + + public PythonPackagesViewModel() + { + packageSource + .Connect() + .DeferUntilLoaded() + .Transform(p => new PythonPackagesItemViewModel { Package = p }) + .SortBy(vm => vm.Package.Name) + .Bind(Packages) + .Subscribe(); + } + + private async Task Refresh() + { + if (VenvPath is null) + return; + + IsLoading = true; + + try + { + if (Design.IsDesignMode) + { + await Task.Delay(250); + } + else + { + await using var venvRunner = new PyVenvRunner(VenvPath); + + var packages = await venvRunner.PipList(); + packageSource.EditDiff(packages); + } + } + finally + { + IsLoading = false; + } + } + + [RelayCommand] + private async Task RefreshBackground() + { + if (VenvPath is null) + return; + + await using var venvRunner = new PyVenvRunner(VenvPath); + + var packages = await venvRunner.PipList(); + + Dispatcher.UIThread.Post(() => + { + // Backup selected package + var currentPackageName = SelectedPackage?.Package.Name; + + packageSource.EditDiff(packages); + + // Restore selected package + SelectedPackage = Packages.FirstOrDefault(p => p.Package.Name == currentPackageName); + }); + } + + /// + /// Load the selected package's show info if not already loaded + /// + partial void OnSelectedPackageChanged(PythonPackagesItemViewModel? value) + { + if (value is null) + { + return; + } + + if (value.PipShowResult is null) + { + value.LoadExtraInfo(VenvPath!).SafeFireAndForget(); + } + } + + /// + public override Task OnLoadedAsync() + { + return Refresh(); + } + + public void AddPackages(params PipPackageInfo[] packages) + { + packageSource.AddOrUpdate(packages); + } + + [RelayCommand] + private Task ModifySelectedPackage(PythonPackagesItemViewModel? item) + { + return item?.SelectedVersion != null + ? UpgradePackageVersion( + item.Package.Name, + item.SelectedVersion, + PythonPackagesItemViewModel.GetKnownIndexUrl( + item.Package.Name, + item.SelectedVersion + ), + isDowngrade: item.CanDowngrade + ) + : Task.CompletedTask; + } + + private async Task UpgradePackageVersion( + string packageName, + string version, + string? extraIndexUrl = null, + bool isDowngrade = false + ) + { + if (VenvPath is null || SelectedPackage?.Package is not { } package) + return; + + // Confirmation dialog + var dialog = DialogHelper.CreateMarkdownDialog( + isDowngrade + ? $"Downgrade **{package.Name}** to **{version}**?" + : $"Upgrade **{package.Name}** to **{version}**?", + Resources.Label_ConfirmQuestion + ); + + dialog.PrimaryButtonText = isDowngrade + ? Resources.Action_Downgrade + : Resources.Action_Upgrade; + dialog.IsPrimaryButtonEnabled = true; + dialog.DefaultButton = ContentDialogButton.Primary; + dialog.CloseButtonText = Resources.Action_Cancel; + + if (await dialog.ShowAsync() is not ContentDialogResult.Primary) + { + return; + } + + var args = new ProcessArgsBuilder("install", $"{packageName}=={version}"); + + if (extraIndexUrl != null) + { + args = args.AddArg(("--extra-index-url", extraIndexUrl)); + } + + var steps = new List + { + new PipStep + { + VenvDirectory = VenvPath, + WorkingDirectory = VenvPath.Parent, + Args = args + } + }; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = isDowngrade + ? $"Downgraded Python Package '{packageName}' to {version}" + : $"Upgraded Python Package '{packageName}' to {version}" + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + await runner.ExecuteSteps(steps); + + // Refresh + RefreshBackground().SafeFireAndForget(); + } + + [RelayCommand] + private async Task InstallPackage() + { + if (VenvPath is null) + return; + + // Dialog + var fields = new TextBoxField[] + { + new() { Label = "Package Name", InnerLeftText = "pip install" } + }; + + var dialog = DialogHelper.CreateTextEntryDialog("Install Package", "", fields); + var result = await dialog.ShowAsync(); + + if (result != ContentDialogResult.Primary || fields[0].Text is not { } packageName) + { + return; + } + + var steps = new List + { + new PipStep + { + VenvDirectory = VenvPath, + WorkingDirectory = VenvPath.Parent, + Args = new[] { "install", packageName } + } + }; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = $"Installed Python Package '{packageName}'" + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + await runner.ExecuteSteps(steps); + + // Refresh + RefreshBackground().SafeFireAndForget(); + } + + [RelayCommand] + private async Task UninstallSelectedPackage() + { + if (VenvPath is null || SelectedPackage?.Package is not { } package) + return; + + // Confirmation dialog + var dialog = DialogHelper.CreateMarkdownDialog( + $"This will uninstall the package '{package.Name}'", + Resources.Label_ConfirmQuestion + ); + dialog.PrimaryButtonText = Resources.Action_Uninstall; + dialog.IsPrimaryButtonEnabled = true; + dialog.DefaultButton = ContentDialogButton.Primary; + dialog.CloseButtonText = Resources.Action_Cancel; + + if (await dialog.ShowAsync() is not ContentDialogResult.Primary) + { + return; + } + + var steps = new List + { + new PipStep + { + VenvDirectory = VenvPath, + WorkingDirectory = VenvPath.Parent, + Args = new[] { "uninstall", "--yes", package.Name } + } + }; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = $"Uninstalled Python Package '{package.Name}'" + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + await runner.ExecuteSteps(steps); + + // Refresh + RefreshBackground().SafeFireAndForget(); + } + + public BetterContentDialog GetDialog() + { + return new BetterContentDialog + { + CloseOnClickOutside = true, + MinDialogWidth = 800, + MaxDialogWidth = 1500, + FullSizeDesired = true, + ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled, + Title = Resources.Label_PythonPackages, + Content = new PythonPackagesDialog { DataContext = this }, + CloseButtonText = Resources.Action_Close, + DefaultButton = ContentDialogButton.Close + }; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs index 15fa1e48..9c20bdf0 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs @@ -19,6 +19,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(SelectDataDirectoryDialog))] +[ManagedService] +[Transient] public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs index 77eac510..c3f10931 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs @@ -1,19 +1,19 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq; -using System.Threading.Tasks; using Avalonia.Media.Imaging; -using Avalonia.Platform; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using FluentAvalonia.UI.Controls; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; +[ManagedService] +[Transient] public partial class SelectModelVersionViewModel : ContentDialogViewModelBase { private readonly ISettingsManager settingsManager; diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs index 8710295a..8ee671a1 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs @@ -22,6 +22,8 @@ using StabilityMatrix.Core.Updater; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; [View(typeof(UpdateDialog))] +[ManagedService] +[Singleton] public partial class UpdateViewModel : ContentDialogViewModelBase { private readonly ISettingsManager settingsManager; diff --git a/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs index 25a6faed..54463c21 100644 --- a/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs @@ -11,6 +11,8 @@ using StabilityMatrix.Core.Helper; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(FirstLaunchSetupWindow))] +[ManagedService] +[Singleton] public partial class FirstLaunchSetupViewModel : ViewModelBase { [ObservableProperty] @@ -20,14 +22,17 @@ public partial class FirstLaunchSetupViewModel : ViewModelBase private string gpuInfoText = string.Empty; [ObservableProperty] - private RefreshBadgeViewModel checkHardwareBadge = new() - { - WorkingToolTipText = "We're checking some hardware specifications to determine compatibility.", - SuccessToolTipText = "Everything looks good!", - FailToolTipText = "We recommend a GPU with CUDA support for the best experience. " + - "You can continue without one, but some packages may not work, and inference may be slower.", - FailColorBrush = ThemeColors.ThemeYellow, - }; + private RefreshBadgeViewModel checkHardwareBadge = + new() + { + WorkingToolTipText = + "We're checking some hardware specifications to determine compatibility.", + SuccessToolTipText = "Everything looks good!", + FailToolTipText = + "We recommend a GPU with CUDA support for the best experience. " + + "You can continue without one, but some packages may not work, and inference may be slower.", + FailColorBrush = ThemeColors.ThemeYellow, + }; public FirstLaunchSetupViewModel() { @@ -43,14 +48,16 @@ public partial class FirstLaunchSetupViewModel : ViewModelBase gpuInfo = await Task.Run(() => HardwareHelper.IterGpuInfo().ToArray()); } // First Nvidia GPU - var activeGpu = gpuInfo.FirstOrDefault(gpu => gpu.Name?.ToLowerInvariant().Contains("nvidia") ?? false); + var activeGpu = gpuInfo.FirstOrDefault( + gpu => gpu.Name?.ToLowerInvariant().Contains("nvidia") ?? false + ); var isNvidia = activeGpu is not null; // Otherwise first GPU activeGpu ??= gpuInfo.FirstOrDefault(); GpuInfoText = activeGpu is null ? "No GPU detected" : $"{activeGpu.Name} ({Size.FormatBytes(activeGpu.MemoryBytes)})"; - + return isNvidia; } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs index d18d22bd..12a1d0e9 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs @@ -6,6 +6,8 @@ using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(BatchSizeCard))] +[ManagedService] +[Transient] public partial class BatchSizeCardViewModel : LoadableViewModelBase { [ObservableProperty] diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs index 059975d3..91b09866 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs @@ -7,6 +7,8 @@ using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(FreeUCard))] +[ManagedService] +[Transient] public partial class FreeUCardViewModel : LoadableViewModelBase { [ObservableProperty] diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs index eab2a345..65f9fa0a 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; +using System.Linq; using StabilityMatrix.Avalonia.Models; namespace StabilityMatrix.Avalonia.ViewModels.Inference; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs index e8097834..7502f496 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs @@ -4,7 +4,6 @@ using System.Reactive.Linq; using System.Threading.Tasks; using AsyncAwaitBestPractices; using AsyncImageLoader; -using Avalonia; using Avalonia.Controls.Notifications; using Avalonia.Platform.Storage; using Avalonia.Threading; @@ -26,6 +25,7 @@ using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Settings; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Services; using SortDirection = DynamicData.Binding.SortDirection; @@ -33,6 +33,8 @@ using SortDirection = DynamicData.Binding.SortDirection; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(ImageFolderCard))] +[ManagedService] +[Transient] public partial class ImageFolderCardViewModel : ViewModelBase { private readonly ILogger logger; @@ -43,6 +45,9 @@ public partial class ImageFolderCardViewModel : ViewModelBase [ObservableProperty] private string? searchQuery; + [ObservableProperty] + private Size imageSize = new(150, 190); + /// /// Collection of local image files /// @@ -61,20 +66,28 @@ public partial class ImageFolderCardViewModel : ViewModelBase this.settingsManager = settingsManager; this.notificationService = notificationService; - var predicate = this.WhenPropertyChanged(vm => vm.SearchQuery) + var searcher = new ImageSearcher(); + + // Observable predicate from SearchQuery changes + var searchPredicate = this.WhenPropertyChanged(vm => vm.SearchQuery) .Throttle(TimeSpan.FromMilliseconds(50))! - .Select, Func>( - p => file => SearchPredicate(file, p.Value) - ) + .Select(property => searcher.GetPredicate(property.Value)) .AsObservable(); imageIndexService.InferenceImages.ItemsSource .Connect() .DeferUntilLoaded() - .Filter(predicate) + .Filter(searchPredicate) .SortBy(file => file.LastModifiedAt, SortDirection.Descending) .Bind(LocalImages) .Subscribe(); + + settingsManager.RelayPropertyFor( + this, + vm => vm.ImageSize, + settings => settings.InferenceImageSize, + delay: TimeSpan.FromMilliseconds(250) + ); } private static bool SearchPredicate(LocalImageFile file, string? query) @@ -116,24 +129,49 @@ public partial class ImageFolderCardViewModel : ViewModelBase public override async Task OnLoadedAsync() { await base.OnLoadedAsync(); - + ImageSize = settingsManager.Settings.InferenceImageSize; imageIndexService.RefreshIndexForAllCollections().SafeFireAndForget(); } + /// + /// Gets the image path if it exists, returns null. + /// If the image path is resolved but the file doesn't exist, it will be removed from the index. + /// + private FilePath? GetImagePathIfExists(LocalImageFile item) + { + var imageFile = new FilePath(item.AbsolutePath); + + if (!imageFile.Exists) + { + // Remove from index + imageIndexService.InferenceImages.Remove(item); + + // Invalidate cache + if (ImageLoader.AsyncImageLoader is FallbackRamCachedWebImageLoader loader) + { + loader.RemoveAllNamesFromCache(imageFile.Name); + } + + return null; + } + + return imageFile; + } + /// /// Handles image clicks to show preview /// [RelayCommand] private async Task OnImageClick(LocalImageFile item) { - if (item.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath) + if (GetImagePathIfExists(item) is not { } imageFile) { return; } var currentIndex = LocalImages.IndexOf(item); - var image = new ImageSource(new FilePath(imagePath)); + var image = new ImageSource(imageFile); // Preload await image.GetBitmapAsync(); @@ -156,14 +194,12 @@ public partial class ImageFolderCardViewModel : ViewModelBase if (newIndex >= 0 && newIndex < LocalImages.Count) { var newImage = LocalImages[newIndex]; - var newImageSource = new ImageSource( - new FilePath(newImage.GetFullPath(settingsManager.ImagesDirectory)) - ); + var newImageSource = new ImageSource(newImage.AbsolutePath); // Preload await newImageSource.GetBitmapAsync(); - var oldImageSource = sender.ImageSource; + // var oldImageSource = sender.ImageSource; sender.ImageSource = newImageSource; sender.LocalImageFile = newImage; @@ -185,13 +221,12 @@ public partial class ImageFolderCardViewModel : ViewModelBase [RelayCommand] private async Task OnImageDelete(LocalImageFile? item) { - if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath) + if (item is null || GetImagePathIfExists(item) is not { } imageFile) { return; } // Delete the file - var imageFile = new FilePath(imagePath); var result = await notificationService.TryAsync(imageFile.DeleteAsync()); if (!result.IsSuccessful) @@ -215,14 +250,14 @@ public partial class ImageFolderCardViewModel : ViewModelBase [RelayCommand] private async Task OnImageCopy(LocalImageFile? item) { - if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath) + if (item is null || GetImagePathIfExists(item) is not { } imageFile) { return; } var clipboard = App.Clipboard; - await clipboard.SetFileDataObjectAsync(imagePath); + await clipboard.SetFileDataObjectAsync(imageFile.FullPath); } /// @@ -231,12 +266,12 @@ public partial class ImageFolderCardViewModel : ViewModelBase [RelayCommand] private async Task OnImageOpen(LocalImageFile? item) { - if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath) + if (item is null || GetImagePathIfExists(item) is not { } imageFile) { return; } - await ProcessRunner.OpenFileBrowser(imagePath); + await ProcessRunner.OpenFileBrowser(imageFile); } /// @@ -248,13 +283,11 @@ public partial class ImageFolderCardViewModel : ViewModelBase bool includeMetadata = false ) { - if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } sourcePath) + if (item is null || GetImagePathIfExists(item) is not { } sourceFile) { return; } - var sourceFile = new FilePath(sourcePath); - var formatName = format.ToString(); var storageFile = await App.StorageProvider.SaveFilePickerAsync( diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs index a8f5920c..93497d43 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs @@ -23,6 +23,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(ImageGalleryCard))] +[ManagedService] +[Transient] public partial class ImageGalleryCardViewModel : ViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs index 9e56a16d..d2a196e6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs @@ -1,13 +1,10 @@ using System; using System.Diagnostics.CodeAnalysis; -using System.Drawing; using System.Linq; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using AsyncAwaitBestPractices; -using Avalonia.Controls.Shapes; -using Avalonia.Threading; using DynamicData.Binding; using NLog; using StabilityMatrix.Avalonia.Extensions; @@ -19,6 +16,7 @@ using StabilityMatrix.Avalonia.Views.Inference; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; +using StabilityMatrix.Core.Services; using Path = System.IO.Path; #pragma warning disable CS0657 // Not a valid attribute location for this declaration @@ -27,6 +25,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(InferenceImageUpscaleView), persistent: true)] [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] +[ManagedService] +[Transient] public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -60,9 +60,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase public InferenceImageUpscaleViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, + ISettingsManager settingsManager, ServiceManager vmFactory ) - : base(vmFactory, inferenceClientManager, notificationService) + : base(vmFactory, inferenceClientManager, notificationService, settingsManager) { this.notificationService = notificationService; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 07124ac0..c06c438c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -26,6 +26,8 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(InferenceTextToImageView), persistent: true)] +[ManagedService] +[Transient] public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, IParametersLoadableState @@ -86,10 +88,11 @@ public class InferenceTextToImageViewModel public InferenceTextToImageViewModel( INotificationService notificationService, IInferenceClientManager inferenceClientManager, + ISettingsManager settingsManager, ServiceManager vmFactory, IModelIndexService modelIndexService ) - : base(vmFactory, inferenceClientManager, notificationService) + : base(vmFactory, inferenceClientManager, notificationService, settingsManager) { this.notificationService = notificationService; this.modelIndexService = modelIndexService; @@ -248,7 +251,7 @@ public class InferenceTextToImageViewModel if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false }) { var customVaeLoader = nodes.AddNamedNode( - ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.FileName) + ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath) ); builder.Connections.BaseVAE = customVaeLoader.Output; @@ -381,22 +384,7 @@ public class InferenceTextToImageViewModel Client = ClientManager.Client, Nodes = buildPromptArgs.Builder.ToNodeDictionary(), OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(), - Parameters = new GenerationParameters - { - Seed = (ulong)seed, - Steps = SamplerCardViewModel.Steps, - CfgScale = SamplerCardViewModel.CfgScale, - Sampler = SamplerCardViewModel.SelectedSampler?.Name, - ModelName = ModelCardViewModel.SelectedModelName, - ModelHash = ModelCardViewModel - .SelectedModel - ?.Local - ?.ConnectedModelInfo - ?.Hashes - .SHA256, - PositivePrompt = PromptCardViewModel.PromptDocument.Text, - NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text - }, + Parameters = SaveStateToParameters(new GenerationParameters()), Project = InferenceProjectDocument.FromLoadable(this), // Only clear output images on the first batch ClearOutputImages = i == 0 @@ -417,10 +405,9 @@ public class InferenceTextToImageViewModel { PromptCardViewModel.LoadStateFromParameters(parameters); SamplerCardViewModel.LoadStateFromParameters(parameters); + ModelCardViewModel.LoadStateFromParameters(parameters); SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed); - - ModelCardViewModel.LoadStateFromParameters(parameters); } /// @@ -428,11 +415,10 @@ public class InferenceTextToImageViewModel { parameters = PromptCardViewModel.SaveStateToParameters(parameters); parameters = SamplerCardViewModel.SaveStateToParameters(parameters); + parameters = ModelCardViewModel.SaveStateToParameters(parameters); parameters.Seed = (ulong)SeedCardViewModel.Seed; - parameters = ModelCardViewModel.SaveStateToParameters(parameters); - return parameters; } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs index 2c28da8f..d69f5b80 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs @@ -12,6 +12,8 @@ using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(ModelCard))] +[ManagedService] +[Transient] public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState { [ObservableProperty] @@ -29,9 +31,9 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad [ObservableProperty] private bool isVaeSelectionEnabled; - public string? SelectedModelName => SelectedModel?.FileName; + public string? SelectedModelName => SelectedModel?.RelativePath; - public string? SelectedVaeName => SelectedVae?.FileName; + public string? SelectedVaeName => SelectedVae?.RelativePath; public IInferenceClientManager ClientManager { get; } @@ -60,11 +62,11 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad SelectedModel = model.SelectedModelName is null ? null - : ClientManager.Models.FirstOrDefault(x => x.FileName == model.SelectedModelName); + : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); SelectedVae = model.SelectedVaeName is null ? HybridModelFile.Default - : ClientManager.VaeModels.FirstOrDefault(x => x.FileName == model.SelectedVaeName); + : ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName); } internal class ModelCardModel @@ -99,7 +101,7 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad else { // Name matches - model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName)); + model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName)); model ??= currentModels.FirstOrDefault( m => m.ShortDisplayName.StartsWith(paramsModelName) ); @@ -114,6 +116,10 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad /// public GenerationParameters SaveStateToParameters(GenerationParameters parameters) { - return parameters with { ModelName = SelectedModel?.FileName }; + return parameters with + { + ModelName = SelectedModel?.FileName, + ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256 + }; } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs index a7b76ad0..26e8acf8 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs @@ -1,15 +1,9 @@ -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Linq; -using System.Text; +using System.Text; using System.Text.Json; using System.Text.Json.Nodes; using System.Threading.Tasks; using AvaloniaEdit; using AvaloniaEdit.Document; -using AvaloniaEdit.Editing; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using StabilityMatrix.Avalonia.Controls; @@ -20,7 +14,6 @@ using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Exceptions; -using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Services; @@ -28,6 +21,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(PromptCard))] +[ManagedService] +[Transient] public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState { private readonly IModelIndexService modelIndexService; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index 1d70fa42..8d52ebe9 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -13,6 +13,8 @@ using StabilityMatrix.Core.Models.Api.Comfy; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(SamplerCard))] +[ManagedService] +[Transient] public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState { [ObservableProperty] diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs index 6c52c4db..1f0a9fa4 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs @@ -10,6 +10,8 @@ using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(SeedCard))] +[ManagedService] +[Transient] public partial class SeedCardViewModel : LoadableViewModelBase { [ObservableProperty, NotifyPropertyChangedFor(nameof(RandomizeButtonToolTip))] diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs index ff175ccd..7639bc14 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Drawing; -using System.Text.Json; using System.Linq; using Avalonia.Input; using Avalonia.Media; @@ -19,6 +18,8 @@ using StabilityMatrix.Core.Models.Database; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(SelectImageCard))] +[ManagedService] +[Transient] public partial class SelectImageCardViewModel : ViewModelBase, IDropTarget { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -88,7 +89,7 @@ public partial class SelectImageCardViewModel : ViewModelBase, IDropTarget { var current = ImageSource; - ImageSource = new ImageSource(imageFile.GlobalFullPath); + ImageSource = new ImageSource(imageFile.AbsolutePath); current?.Dispose(); }); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs index 98a1f933..90a67a12 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs @@ -7,6 +7,8 @@ using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(SharpenCard))] +[ManagedService] +[Transient] public partial class SharpenCardViewModel : LoadableViewModelBase { [Range(1, 31)] diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs index 763054b2..49e7b810 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs @@ -8,20 +8,24 @@ using StabilityMatrix.Core.Extensions; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(StackCard))] +[ManagedService] +[Transient] public class StackCardViewModel : StackViewModelBase { /// public override void LoadStateFromJsonObject(JsonObject state) { var model = DeserializeModel(state); - - if (model.Cards is null) return; - + + if (model.Cards is null) + return; + foreach (var (i, card) in model.Cards.Enumerate()) { // Ignore if more than cards than we have - if (i > Cards.Count - 1) break; - + if (i > Cards.Count - 1) + break; + Cards[i].LoadStateFromJsonObject(card); } } @@ -29,9 +33,8 @@ public class StackCardViewModel : StackViewModelBase /// public override JsonObject SaveStateToJsonObject() { - return SerializeModel(new StackCardModel - { - Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList() - }); + return SerializeModel( + new StackCardModel { Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList() } + ); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs index 2b4e3c2e..9eef7ad6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs @@ -11,28 +11,32 @@ using StabilityMatrix.Core.Extensions; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(StackExpander))] +[ManagedService] +[Transient] public partial class StackExpanderViewModel : StackViewModelBase { [ObservableProperty] [property: JsonIgnore] private string? title; - - [ObservableProperty] + + [ObservableProperty] private bool isEnabled; - + /// public override void LoadStateFromJsonObject(JsonObject state) { var model = DeserializeModel(state); IsEnabled = model.IsEnabled; - - if (model.Cards is null) return; - + + if (model.Cards is null) + return; + foreach (var (i, card) in model.Cards.Enumerate()) { // Ignore if more than cards than we have - if (i > Cards.Count - 1) break; - + if (i > Cards.Count - 1) + break; + Cards[i].LoadStateFromJsonObject(card); } } @@ -40,10 +44,12 @@ public partial class StackExpanderViewModel : StackViewModelBase /// public override JsonObject SaveStateToJsonObject() { - return SerializeModel(new StackExpanderModel - { - IsEnabled = IsEnabled, - Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList() - }); + return SerializeModel( + new StackExpanderModel + { + IsEnabled = IsEnabled, + Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList() + } + ); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs index ceef75c6..02ef8a83 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs @@ -20,6 +20,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(UpscalerCard))] +[ManagedService] +[Transient] public partial class UpscalerCardViewModel : LoadableViewModelBase { private readonly INotificationService notificationService; diff --git a/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs index 4b2ff874..33320aab 100644 --- a/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs @@ -9,7 +9,6 @@ using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Controls; using Avalonia.Controls.Notifications; -using Avalonia.Controls.Shapes; using Avalonia.Platform.Storage; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; @@ -23,7 +22,6 @@ using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.Views; -using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Database; using StabilityMatrix.Core.Extensions; @@ -43,6 +41,7 @@ namespace StabilityMatrix.Avalonia.ViewModels; [Preload] [View(typeof(InferencePage))] +[Singleton] public partial class InferenceViewModel : PageViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -111,6 +110,10 @@ public partial class InferenceViewModel : PageViewModelBase // Keep RunningPackage updated with the current package pair EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged; + // "Send to Inference" + EventManager.Instance.InferenceTextToImageRequested += OnInferenceTextToImageRequested; + EventManager.Instance.InferenceUpscaleRequested += OnInferenceUpscaleRequested; + MenuSaveAsCommand.WithConditionalNotificationErrorHandler(notificationService); MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService); } @@ -232,6 +235,16 @@ public partial class InferenceViewModel : PageViewModelBase } } + private void OnInferenceTextToImageRequested(object? sender, LocalImageFile e) + { + Dispatcher.UIThread.Post(() => AddTabFromImage(e).SafeFireAndForget()); + } + + private void OnInferenceUpscaleRequested(object? sender, LocalImageFile e) + { + Dispatcher.UIThread.Post(() => AddUpscalerTabFromImage(e).SafeFireAndForget()); + } + /// /// Update the database with current tabs /// @@ -593,6 +606,70 @@ public partial class InferenceViewModel : PageViewModelBase await SyncTabStatesWithDatabase(); } + private async Task AddTabFromImage(LocalImageFile imageFile) + { + var metadata = imageFile.ReadMetadata(); + InferenceTabViewModelBase? vm = null; + + if (!string.IsNullOrWhiteSpace(metadata.SMProject)) + { + var document = JsonSerializer.Deserialize(metadata.SMProject); + if (document is null) + { + throw new ApplicationException( + "MenuOpenProject: Deserialize project file returned null" + ); + } + + if (document.State is null) + { + throw new ApplicationException("Project file does not have 'State' key"); + } + + document.VerifyVersion(); + var textToImage = vmFactory.Get(); + textToImage.LoadStateFromJsonObject(document.State); + vm = textToImage; + } + else if (!string.IsNullOrWhiteSpace(metadata.Parameters)) + { + if (GenerationParameters.TryParse(metadata.Parameters, out var generationParameters)) + { + var textToImageViewModel = vmFactory.Get(); + textToImageViewModel.LoadStateFromParameters(generationParameters); + vm = textToImageViewModel; + } + } + + if (vm == null) + { + notificationService.Show( + "Unable to load project from image", + "No image metadata found", + NotificationType.Error + ); + return; + } + + Tabs.Add(vm); + + SelectedTab = vm; + + await SyncTabStatesWithDatabase(); + } + + private async Task AddUpscalerTabFromImage(LocalImageFile imageFile) + { + var upscaleVm = vmFactory.Get(); + upscaleVm.IsUpscaleEnabled = true; + upscaleVm.SelectImageCardViewModel.ImageSource = new ImageSource(imageFile.AbsolutePath); + + Tabs.Add(upscaleVm); + SelectedTab = upscaleVm; + + await SyncTabStatesWithDatabase(); + } + /// /// Menu "Open Project" command. /// diff --git a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs index 87c7ea3f..b9ed79ab 100644 --- a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs @@ -40,6 +40,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(LaunchPageView))] +[Singleton] public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyncDisposable { private readonly ILogger logger; diff --git a/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs index b165aef0..87ebb0f5 100644 --- a/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs @@ -2,14 +2,12 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Collections.ObjectModel; -using System.IO; using System.Linq; using System.Net.Http; using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Controls; using Avalonia.Controls.Notifications; -using AvaloniaEdit.Utils; using CommunityToolkit.Mvvm.ComponentModel; using FluentAvalonia.UI.Controls; using Microsoft.Extensions.Logging; @@ -35,6 +33,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(NewCheckpointsPage))] +[Singleton] public partial class NewCheckpointsPageViewModel : PageViewModelBase { private readonly ILogger logger; @@ -44,12 +43,17 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase private readonly ServiceManager dialogFactory; private readonly INotificationService notificationService; public override string Title => "Checkpoint Manager"; - public override IconSource IconSource => new SymbolIconSource - {Symbol = Symbol.Cellular5g, IsFilled = true}; + public override IconSource IconSource => + new SymbolIconSource { Symbol = Symbol.Cellular5g, IsFilled = true }; - public NewCheckpointsPageViewModel(ILogger logger, - ISettingsManager settingsManager, ILiteDbContext liteDbContext, ICivitApi civitApi, - ServiceManager dialogFactory, INotificationService notificationService) + public NewCheckpointsPageViewModel( + ILogger logger, + ISettingsManager settingsManager, + ILiteDbContext liteDbContext, + ICivitApi civitApi, + ServiceManager dialogFactory, + INotificationService notificationService + ) { this.logger = logger; this.settingsManager = settingsManager; @@ -63,23 +67,27 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase [NotifyPropertyChangedFor(nameof(ConnectedCheckpoints))] [NotifyPropertyChangedFor(nameof(NonConnectedCheckpoints))] private ObservableCollection allCheckpoints = new(); - + [ObservableProperty] private ObservableCollection civitModels = new(); - public ObservableCollection ConnectedCheckpoints => new( - AllCheckpoints.Where(x => x.IsConnectedModel) - .OrderBy(x => x.ConnectedModel!.ModelName) - .ThenBy(x => x.ModelType) - .GroupBy(x => x.ConnectedModel!.ModelId) - .Select(x => x.First())); + public ObservableCollection ConnectedCheckpoints => + new( + AllCheckpoints + .Where(x => x.IsConnectedModel) + .OrderBy(x => x.ConnectedModel!.ModelName) + .ThenBy(x => x.ModelType) + .GroupBy(x => x.ConnectedModel!.ModelId) + .Select(x => x.First()) + ); - public ObservableCollection NonConnectedCheckpoints => new( - AllCheckpoints.Where(x => !x.IsConnectedModel).OrderBy(x => x.ModelType)); + public ObservableCollection NonConnectedCheckpoints => + new(AllCheckpoints.Where(x => !x.IsConnectedModel).OrderBy(x => x.ModelType)); public override async Task OnLoadedAsync() { - if (Design.IsDesignMode) return; + if (Design.IsDesignMode) + return; var files = CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory); AllCheckpoints = new ObservableCollection(files); @@ -89,17 +97,17 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase { CommaSeparatedModelIds = string.Join(',', connectedModelIds) }; - + // See if query is cached var cachedQuery = await liteDbContext.CivitModelQueryCache .IncludeAll() .FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest)); - + // If cached, update model cards if (cachedQuery is not null) { CivitModels = new ObservableCollection(cachedQuery.Items); - + // Start remote query (background mode) // Skip when last query was less than 2 min ago var timeSinceCache = DateTimeOffset.UtcNow - cachedQuery.InsertedAt; @@ -113,24 +121,34 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase await CivitQuery(modelRequest); } } - + public async Task ShowVersionDialog(int modelId) { var model = CivitModels.FirstOrDefault(m => m.Id == modelId); if (model == null) { - notificationService.Show(new Notification("Model has no versions available", - "This model has no versions available for download", NotificationType.Warning)); + notificationService.Show( + new Notification( + "Model has no versions available", + "This model has no versions available for download", + NotificationType.Warning + ) + ); return; } var versions = model.ModelVersions; if (versions is null || versions.Count == 0) { - notificationService.Show(new Notification("Model has no versions available", - "This model has no versions available for download", NotificationType.Warning)); + notificationService.Show( + new Notification( + "Model has no versions available", + "This model has no versions available for download", + NotificationType.Warning + ) + ); return; } - + var dialog = new BetterContentDialog { Title = model.Name, @@ -139,19 +157,21 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase IsFooterVisible = false, MaxDialogWidth = 750, }; - + var viewModel = dialogFactory.Get(); viewModel.Dialog = dialog; - viewModel.Versions = versions.Select(version => - new ModelVersionViewModel( - settingsManager.Settings.InstalledModelHashes ?? new HashSet(), version)) + viewModel.Versions = versions + .Select( + version => + new ModelVersionViewModel( + settingsManager.Settings.InstalledModelHashes ?? new HashSet(), + version + ) + ) .ToImmutableArray(); viewModel.SelectedVersionViewModel = viewModel.Versions[0]; - - dialog.Content = new SelectModelVersionDialog - { - DataContext = viewModel - }; + + dialog.Content = new SelectModelVersionDialog { DataContext = viewModel }; var result = await dialog.ShowAsync(); @@ -171,8 +191,10 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase var modelResponse = await civitApi.GetModels(request); var models = modelResponse.Items; // Filter out unknown model types and archived/taken-down models - models = models.Where(m => m.Type.ConvertTo() > 0) - .Where(m => m.Mode == null).ToList(); + models = models + .Where(m => m.Type.ConvertTo() > 0) + .Where(m => m.Mode == null) + .ToList(); // Database update calls will invoke `OnModelsUpdated` // Add to database @@ -186,7 +208,8 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase Request = request, Items = models, Metadata = modelResponse.Metadata - }); + } + ); if (cacheNew) { @@ -195,26 +218,42 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase } catch (OperationCanceledException) { - notificationService.Show(new Notification("Request to CivitAI timed out", - "Could not check for checkpoint updates. Please try again later.")); + notificationService.Show( + new Notification( + "Request to CivitAI timed out", + "Could not check for checkpoint updates. Please try again later." + ) + ); logger.LogWarning($"CivitAI query timed out ({request})"); } catch (HttpRequestException e) { - notificationService.Show(new Notification("CivitAI can't be reached right now", - "Could not check for checkpoint updates. Please try again later.")); + notificationService.Show( + new Notification( + "CivitAI can't be reached right now", + "Could not check for checkpoint updates. Please try again later." + ) + ); logger.LogWarning(e, $"CivitAI query HttpRequestException ({request})"); } catch (ApiException e) { - notificationService.Show(new Notification("CivitAI can't be reached right now", - "Could not check for checkpoint updates. Please try again later.")); + notificationService.Show( + new Notification( + "CivitAI can't be reached right now", + "Could not check for checkpoint updates. Please try again later." + ) + ); logger.LogWarning(e, $"CivitAI query ApiException ({request})"); } catch (Exception e) { - notificationService.Show(new Notification("CivitAI can't be reached right now", - $"Unknown exception during CivitAI query: {e.GetType().Name}")); + notificationService.Show( + new Notification( + "CivitAI can't be reached right now", + $"Unknown exception during CivitAI query: {e.GetType().Name}" + ) + ); logger.LogError(e, $"CivitAI query unknown exception ({request})"); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs new file mode 100644 index 00000000..30ce380e --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs @@ -0,0 +1,18 @@ +using CommunityToolkit.Mvvm.ComponentModel; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Models.Database; + +namespace StabilityMatrix.Avalonia.ViewModels.OutputsPage; + +public partial class OutputImageViewModel : ViewModelBase +{ + public LocalImageFile ImageFile { get; } + + [ObservableProperty] + private bool isSelected; + + public OutputImageViewModel(LocalImageFile imageFile) + { + ImageFile = imageFile; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs new file mode 100644 index 00000000..8352f6d5 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs @@ -0,0 +1,531 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Reactive.Linq; +using System.Threading.Tasks; +using AsyncAwaitBestPractices; +using AsyncImageLoader; +using Avalonia; +using Avalonia.Controls; +using Avalonia.Media; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using DynamicData; +using DynamicData.Binding; +using FluentAvalonia.UI.Controls; +using Microsoft.Extensions.Logging; +using Nito.Disposables.Internals; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Avalonia.ViewModels.OutputsPage; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.Factory; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Services; +using Size = StabilityMatrix.Core.Models.Settings.Size; +using Symbol = FluentIcons.Common.Symbol; +using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; + +namespace StabilityMatrix.Avalonia.ViewModels; + +[View(typeof(Views.OutputsPage))] +[Singleton] +public partial class OutputsPageViewModel : PageViewModelBase +{ + private readonly ISettingsManager settingsManager; + private readonly IPackageFactory packageFactory; + private readonly INotificationService notificationService; + private readonly INavigationService navigationService; + private readonly ILogger logger; + public override string Title => Resources.Label_OutputsPageTitle; + + public override IconSource IconSource => + new SymbolIconSource { Symbol = Symbol.Grid, IsFilled = true }; + + public SourceCache OutputsCache { get; } = + new(file => file.AbsolutePath); + + public IObservableCollection Outputs { get; set; } = + new ObservableCollectionExtended(); + + public IEnumerable OutputTypes { get; } = Enum.GetValues(); + + [ObservableProperty] + private ObservableCollection categories; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(CanShowOutputTypes))] + private PackageOutputCategory selectedCategory; + + [ObservableProperty] + private SharedOutputType selectedOutputType; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(NumImagesSelected))] + private int numItemsSelected; + + [ObservableProperty] + private string searchQuery; + + [ObservableProperty] + private Size imageSize = new(300, 300); + + [ObservableProperty] + private bool isConsolidating; + + public bool CanShowOutputTypes => + SelectedCategory?.Name?.Equals("Shared Output Folder") ?? false; + + public string NumImagesSelected => + NumItemsSelected == 1 + ? Resources.Label_OneImageSelected + : string.Format(Resources.Label_NumImagesSelected, NumItemsSelected); + + public OutputsPageViewModel( + ISettingsManager settingsManager, + IPackageFactory packageFactory, + INotificationService notificationService, + INavigationService navigationService, + ILogger logger + ) + { + this.settingsManager = settingsManager; + this.packageFactory = packageFactory; + this.notificationService = notificationService; + this.navigationService = navigationService; + this.logger = logger; + + var searcher = new ImageSearcher(); + + // Observable predicate from SearchQuery changes + var searchPredicate = this.WhenPropertyChanged(vm => vm.SearchQuery) + .Throttle(TimeSpan.FromMilliseconds(50))! + .Select(property => searcher.GetPredicate(property.Value)) + .AsObservable(); + + OutputsCache + .Connect() + .DeferUntilLoaded() + .Filter(searchPredicate) + .Transform(file => new OutputImageViewModel(file)) + .SortBy(vm => vm.ImageFile.CreatedAt, SortDirection.Descending) + .Bind(Outputs) + .WhenPropertyChanged(p => p.IsSelected) + .Subscribe(_ => + { + NumItemsSelected = Outputs.Count(o => o.IsSelected); + }); + + settingsManager.RelayPropertyFor( + this, + vm => vm.ImageSize, + settings => settings.OutputsImageSize, + delay: TimeSpan.FromMilliseconds(250) + ); + } + + public override void OnLoaded() + { + if (Design.IsDesignMode) + return; + + if (!settingsManager.IsLibraryDirSet) + return; + + Directory.CreateDirectory(settingsManager.ImagesDirectory); + + var packageCategories = settingsManager.Settings.InstalledPackages + .Where(x => !x.UseSharedOutputFolder) + .Select(packageFactory.GetPackagePair) + .WhereNotNull() + .Where( + p => + p.BasePackage.SharedOutputFolders != null + && p.BasePackage.SharedOutputFolders.Any() + ) + .Select( + pair => + new PackageOutputCategory + { + Path = Path.Combine( + pair.InstalledPackage.FullPath!, + pair.BasePackage.OutputFolderName + ), + Name = pair.InstalledPackage.DisplayName ?? "" + } + ) + .ToList(); + + packageCategories.Insert( + 0, + new PackageOutputCategory + { + Path = settingsManager.ImagesDirectory, + Name = "Shared Output Folder" + } + ); + + packageCategories.Insert( + 1, + new PackageOutputCategory + { + Path = settingsManager.ImagesInferenceDirectory, + Name = "Inference" + } + ); + + Categories = new ObservableCollection(packageCategories); + SelectedCategory = Categories.First(); + SelectedOutputType = SharedOutputType.All; + SearchQuery = string.Empty; + ImageSize = settingsManager.Settings.OutputsImageSize; + + var path = + CanShowOutputTypes && SelectedOutputType != SharedOutputType.All + ? Path.Combine(SelectedCategory.Path, SelectedOutputType.ToString()) + : SelectedCategory.Path; + GetOutputs(path); + } + + partial void OnSelectedCategoryChanged( + PackageOutputCategory? oldValue, + PackageOutputCategory? newValue + ) + { + if (oldValue == newValue || newValue == null) + return; + + var path = + CanShowOutputTypes && SelectedOutputType != SharedOutputType.All + ? Path.Combine(newValue.Path, SelectedOutputType.ToString()) + : SelectedCategory.Path; + GetOutputs(path); + } + + partial void OnSelectedOutputTypeChanged(SharedOutputType oldValue, SharedOutputType newValue) + { + if (oldValue == newValue) + return; + + var path = + newValue == SharedOutputType.All + ? SelectedCategory.Path + : Path.Combine(SelectedCategory.Path, newValue.ToString()); + GetOutputs(path); + } + + public Task OnImageClick(OutputImageViewModel item) + { + // Select image if we're in "select mode" + if (NumItemsSelected > 0) + { + item.IsSelected = !item.IsSelected; + } + else + { + return ShowImageDialog(item); + } + + return Task.CompletedTask; + } + + public async Task ShowImageDialog(OutputImageViewModel item) + { + var currentIndex = Outputs.IndexOf(item); + + var image = new ImageSource(new FilePath(item.ImageFile.AbsolutePath)); + + // Preload + await image.GetBitmapAsync(); + + var vm = new ImageViewerViewModel { ImageSource = image, LocalImageFile = item.ImageFile }; + + using var onNext = Observable + .FromEventPattern( + vm, + nameof(ImageViewerViewModel.NavigationRequested) + ) + .Subscribe(ctx => + { + Dispatcher.UIThread + .InvokeAsync(async () => + { + var sender = (ImageViewerViewModel)ctx.Sender!; + var newIndex = currentIndex + (ctx.EventArgs.IsNext ? 1 : -1); + + if (newIndex >= 0 && newIndex < Outputs.Count) + { + var newImage = Outputs[newIndex]; + var newImageSource = new ImageSource( + new FilePath(newImage.ImageFile.AbsolutePath) + ); + + // Preload + await newImageSource.GetBitmapAsync(); + + sender.ImageSource = newImageSource; + sender.LocalImageFile = newImage.ImageFile; + + currentIndex = newIndex; + } + }) + .SafeFireAndForget(); + }); + + await vm.GetDialog().ShowAsync(); + } + + public Task CopyImage(string imagePath) + { + var clipboard = App.Clipboard; + return clipboard.SetFileDataObjectAsync(imagePath); + } + + public Task OpenImage(string imagePath) => ProcessRunner.OpenFileBrowser(imagePath); + + public async Task DeleteImage(OutputImageViewModel? item) + { + if (item is null) + return; + + var confirmationDialog = new BetterContentDialog + { + Title = "Are you sure you want to delete this image?", + Content = "This action cannot be undone.", + PrimaryButtonText = Resources.Action_Delete, + SecondaryButtonText = Resources.Action_Cancel, + DefaultButton = ContentDialogButton.Primary, + IsSecondaryButtonEnabled = true, + }; + var dialogResult = await confirmationDialog.ShowAsync(); + if (dialogResult != ContentDialogResult.Primary) + return; + + // Delete the file + var imageFile = new FilePath(item.ImageFile.AbsolutePath); + var result = await notificationService.TryAsync(imageFile.DeleteAsync()); + + if (!result.IsSuccessful) + { + return; + } + + OutputsCache.Remove(item.ImageFile); + + // Invalidate cache + if (ImageLoader.AsyncImageLoader is FallbackRamCachedWebImageLoader loader) + { + loader.RemoveAllNamesFromCache(imageFile.Name); + } + } + + public void SendToTextToImage(OutputImageViewModel vm) + { + navigationService.NavigateTo(); + EventManager.Instance.OnInferenceTextToImageRequested(vm.ImageFile); + } + + public void SendToUpscale(OutputImageViewModel vm) + { + navigationService.NavigateTo(); + EventManager.Instance.OnInferenceUpscaleRequested(vm.ImageFile); + } + + public void ClearSelection() + { + foreach (var output in Outputs) + { + output.IsSelected = false; + } + } + + public void SelectAll() + { + foreach (var output in Outputs) + { + output.IsSelected = true; + } + } + + public async Task DeleteAllSelected() + { + var confirmationDialog = new BetterContentDialog + { + Title = $"Are you sure you want to delete {NumItemsSelected} images?", + Content = "This action cannot be undone.", + PrimaryButtonText = Resources.Action_Delete, + SecondaryButtonText = Resources.Action_Cancel, + DefaultButton = ContentDialogButton.Primary, + IsSecondaryButtonEnabled = true, + }; + var dialogResult = await confirmationDialog.ShowAsync(); + if (dialogResult != ContentDialogResult.Primary) + return; + + var selected = Outputs.Where(o => o.IsSelected).ToList(); + Debug.Assert(selected.Count == NumItemsSelected); + foreach (var output in selected) + { + // Delete the file + var imageFile = new FilePath(output.ImageFile.AbsolutePath); + var result = await notificationService.TryAsync(imageFile.DeleteAsync()); + + if (!result.IsSuccessful) + { + continue; + } + OutputsCache.Remove(output.ImageFile); + + // Invalidate cache + if (ImageLoader.AsyncImageLoader is FallbackRamCachedWebImageLoader loader) + { + loader.RemoveAllNamesFromCache(imageFile.Name); + } + } + + NumItemsSelected = 0; + ClearSelection(); + } + + public async Task ConsolidateImages() + { + var stackPanel = new StackPanel(); + stackPanel.Children.Add( + new TextBlock + { + Text = Resources.Label_ConsolidateExplanation, + TextWrapping = TextWrapping.Wrap, + Margin = new Thickness(0, 8, 0, 16) + } + ); + foreach (var category in Categories) + { + if (category.Name == "Shared Output Folder") + { + continue; + } + + stackPanel.Children.Add( + new CheckBox + { + Content = $"{category.Name} ({category.Path})", + IsChecked = true, + Margin = new Thickness(0, 8, 0, 0), + Tag = category.Path + } + ); + } + + var confirmationDialog = new BetterContentDialog + { + Title = Resources.Label_AreYouSure, + Content = stackPanel, + PrimaryButtonText = Resources.Action_Yes, + SecondaryButtonText = Resources.Action_Cancel, + DefaultButton = ContentDialogButton.Primary, + IsSecondaryButtonEnabled = true, + }; + + var dialogResult = await confirmationDialog.ShowAsync(); + if (dialogResult != ContentDialogResult.Primary) + return; + + IsConsolidating = true; + + Directory.CreateDirectory(settingsManager.ConsolidatedImagesDirectory); + + foreach ( + var category in stackPanel.Children.OfType().Where(c => c.IsChecked == true) + ) + { + if ( + string.IsNullOrWhiteSpace(category.Tag?.ToString()) + || !Directory.Exists(category.Tag?.ToString()) + ) + continue; + + var directory = category.Tag.ToString(); + + foreach ( + var path in Directory.EnumerateFiles( + directory, + "*.png", + SearchOption.AllDirectories + ) + ) + { + try + { + var file = new FilePath(path); + var newPath = settingsManager.ConsolidatedImagesDirectory + file.Name; + if (file.FullPath == newPath) + continue; + + // ignore inference if not in inference directory + if ( + file.FullPath.Contains(settingsManager.ImagesInferenceDirectory) + && directory != settingsManager.ImagesInferenceDirectory + ) + { + continue; + } + + await file.MoveToAsync(newPath); + } + catch (Exception e) + { + logger.LogError(e, "Error when consolidating: "); + } + } + } + + OnLoaded(); + IsConsolidating = false; + } + + private void GetOutputs(string directory) + { + if (!settingsManager.IsLibraryDirSet) + return; + + if ( + !Directory.Exists(directory) + && ( + SelectedCategory.Path != settingsManager.ImagesDirectory + || SelectedOutputType != SharedOutputType.All + ) + ) + { + Directory.CreateDirectory(directory); + return; + } + + var files = Directory + .EnumerateFiles(directory, "*.png", SearchOption.AllDirectories) + .Select(file => LocalImageFile.FromPath(file)) + .ToList(); + + if (files.Count == 0) + { + OutputsCache.Clear(); + } + else + { + OutputsCache.EditDiff(files); + } + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs index d921aa08..467f45d7 100644 --- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using Avalonia.Controls; using Avalonia.Controls.Notifications; using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; using Microsoft.Extensions.Logging; using StabilityMatrix.Avalonia.Animations; @@ -14,6 +15,7 @@ using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Factory; @@ -26,6 +28,8 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.PackageManager; +[ManagedService] +[Transient] public partial class PackageCardViewModel : ProgressViewModel { private readonly ILogger logger; @@ -62,6 +66,15 @@ public partial class PackageCardViewModel : ProgressViewModel [ObservableProperty] private bool canUseConfigMethod; + [ObservableProperty] + private bool canUseSymlinkMethod; + + [ObservableProperty] + private bool useSharedOutput; + + [ObservableProperty] + private bool canUseSharedOutput; + public PackageCardViewModel( ILogger logger, IPackageFactory packageFactory, @@ -103,6 +116,11 @@ public partial class PackageCardViewModel : ProgressViewModel CanUseConfigMethod = basePackage?.AvailableSharedFolderMethods.Contains(SharedFolderMethod.Configuration) ?? false; + CanUseSymlinkMethod = + basePackage?.AvailableSharedFolderMethods.Contains(SharedFolderMethod.Symlink) + ?? false; + UseSharedOutput = Package?.UseSharedOutputFolder ?? false; + CanUseSharedOutput = basePackage?.SharedOutputFolders != null; } } @@ -243,7 +261,29 @@ public partial class PackageCardViewModel : ProgressViewModel { ModificationCompleteMessage = $"{packageName} Update Complete" }; - var updatePackageStep = new UpdatePackageStep(settingsManager, Package, basePackage); + + var versionOptions = new DownloadPackageVersionOptions { IsLatest = true }; + if (Package.Version.IsReleaseMode) + { + versionOptions.VersionTag = await basePackage.GetLatestVersion(); + } + else + { + var commits = await basePackage.GetAllCommits(Package.Version.InstalledBranch); + var latest = commits?.FirstOrDefault(); + if (latest == null) + throw new Exception("Could not find latest commit"); + + versionOptions.BranchName = Package.Version.InstalledBranch; + versionOptions.CommitHash = latest.Sha; + } + + var updatePackageStep = new UpdatePackageStep( + settingsManager, + Package, + versionOptions, + basePackage + ); var steps = new List { updatePackageStep }; EventManager.Instance.OnPackageInstallProgressAdded(runner); @@ -335,6 +375,20 @@ public partial class PackageCardViewModel : ProgressViewModel await ProcessRunner.OpenFolderBrowser(Package.FullPath); } + [RelayCommand] + public async Task OpenPythonPackagesDialog() + { + if (Package is not { FullPath: not null }) + return; + + var vm = vmFactory.Get(vm => + { + vm.VenvPath = new DirectoryPath(Package.FullPath, "venv"); + }); + + await vm.GetDialog().ShowAsync(); + } + private async Task HasUpdate() { if (Package == null || IsUnknownPackage || Design.IsDesignMode) @@ -374,6 +428,33 @@ public partial class PackageCardViewModel : ProgressViewModel public void ToggleSharedModelNone() => IsSharedModelDisabled = !IsSharedModelDisabled; + public void ToggleSharedOutput() => UseSharedOutput = !UseSharedOutput; + + partial void OnUseSharedOutputChanged(bool value) + { + if (Package == null) + return; + + if (value == Package.UseSharedOutputFolder) + return; + + using var st = settingsManager.BeginTransaction(); + Package.UseSharedOutputFolder = value; + + var basePackage = packageFactory[Package.PackageName!]; + if (basePackage == null) + return; + + if (value) + { + basePackage.SetupOutputFolderLinks(Package.FullPath!); + } + else + { + basePackage.RemoveOutputFolderLinks(Package.FullPath!); + } + } + // fake radio button stuff partial void OnIsSharedModelSymlinkChanged(bool oldValue, bool newValue) { diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs index 22c4dd10..5578a5f7 100644 --- a/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs @@ -7,9 +7,12 @@ using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Controls; using Avalonia.Controls.Notifications; +using Avalonia.Controls.Primitives; +using Avalonia.Threading; using DynamicData; using DynamicData.Binding; using FluentAvalonia.UI.Controls; +using Microsoft.Extensions.Logging; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; @@ -19,7 +22,6 @@ using StabilityMatrix.Avalonia.Views; using StabilityMatrix.Avalonia.Views.Dialogs; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; -using StabilityMatrix.Core.Helper.Factory; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.PackageModification; @@ -35,12 +37,13 @@ namespace StabilityMatrix.Avalonia.ViewModels; /// [View(typeof(PackageManagerPage))] +[Singleton] public partial class PackageManagerViewModel : PageViewModelBase { private readonly ISettingsManager settingsManager; - private readonly IPackageFactory packageFactory; private readonly ServiceManager dialogFactory; private readonly INotificationService notificationService; + private readonly ILogger logger; public override string Title => "Packages"; public override IconSource IconSource => @@ -62,17 +65,19 @@ public partial class PackageManagerViewModel : PageViewModelBase public IObservableCollection PackageCards { get; } = new ObservableCollectionExtended(); + private DispatcherTimer timer; + public PackageManagerViewModel( ISettingsManager settingsManager, - IPackageFactory packageFactory, ServiceManager dialogFactory, - INotificationService notificationService + INotificationService notificationService, + ILogger logger ) { this.settingsManager = settingsManager; - this.packageFactory = packageFactory; this.dialogFactory = dialogFactory; this.notificationService = notificationService; + this.logger = logger; EventManager.Instance.InstalledPackagesChanged += OnInstalledPackagesChanged; @@ -93,6 +98,9 @@ public partial class PackageManagerViewModel : PageViewModelBase ) .Bind(PackageCards) .Subscribe(); + + timer = new DispatcherTimer { Interval = TimeSpan.FromMinutes(15), IsEnabled = true }; + timer.Tick += async (_, _) => await CheckPackagesForUpdates(); } public void SetPackages(IEnumerable packages) @@ -117,22 +125,31 @@ public partial class PackageManagerViewModel : PageViewModelBase var currentUnknown = await Task.Run(IndexUnknownPackages); unknownInstalledPackages.Edit(s => s.Load(currentUnknown)); + + timer.Start(); + } + + public override void OnUnloaded() + { + timer.Stop(); + base.OnUnloaded(); } public async Task ShowInstallDialog(BasePackage? selectedPackage = null) { var viewModel = dialogFactory.Get(); - viewModel.AvailablePackages = packageFactory.GetAllAvailablePackages().ToImmutableArray(); viewModel.SelectedPackage = selectedPackage ?? viewModel.AvailablePackages[0]; var dialog = new BetterContentDialog { MaxDialogWidth = 900, MinDialogWidth = 900, + FullSizeDesired = true, DefaultButton = ContentDialogButton.Close, IsPrimaryButtonEnabled = false, IsSecondaryButtonEnabled = false, IsFooterVisible = false, + ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled, Content = new InstallerDialog { DataContext = viewModel } }; @@ -157,6 +174,25 @@ public partial class PackageManagerViewModel : PageViewModelBase } } + private async Task CheckPackagesForUpdates() + { + foreach (var package in PackageCards) + { + try + { + await package.OnLoadedAsync(); + } + catch (Exception e) + { + logger.LogError( + e, + "Failed to check for updates for {Package}", + package?.Package?.PackageName + ); + } + } + } + private IEnumerable IndexUnknownPackages() { var packageDir = new DirectoryPath(settingsManager.LibraryDir).JoinDir("Packages"); diff --git a/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs index dc410872..9cb0eb3d 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs @@ -17,7 +17,10 @@ public class PackageInstallProgressItemViewModel : ProgressItemViewModelBase private readonly IPackageModificationRunner packageModificationRunner; private BetterContentDialog? dialog; - public PackageInstallProgressItemViewModel(IPackageModificationRunner packageModificationRunner) + public PackageInstallProgressItemViewModel( + IPackageModificationRunner packageModificationRunner, + bool hideCloseButton = false + ) { this.packageModificationRunner = packageModificationRunner; Id = packageModificationRunner.Id; @@ -25,6 +28,7 @@ public class PackageInstallProgressItemViewModel : ProgressItemViewModelBase Progress.Value = packageModificationRunner.CurrentProgress.Percentage; Progress.Text = packageModificationRunner.ConsoleOutput.LastOrDefault(); Progress.IsIndeterminate = packageModificationRunner.CurrentProgress.IsIndeterminate; + Progress.HideCloseButton = hideCloseButton; Progress.Console.StartUpdates(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs index 51b88020..445f0e0c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs @@ -22,6 +22,8 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels.Progress; [View(typeof(ProgressManagerPage))] +[ManagedService] +[Singleton] public partial class ProgressManagerViewModel : PageViewModelBase { private readonly INotificationService notificationService; @@ -120,19 +122,22 @@ public partial class ProgressManagerViewModel : PageViewModelBase } } - private async Task AddPackageInstall(IPackageModificationRunner packageModificationRunner) + private Task AddPackageInstall(IPackageModificationRunner packageModificationRunner) { if (ProgressItems.Any(vm => vm.Id == packageModificationRunner.Id)) { - return; + return Task.CompletedTask; } - var vm = new PackageInstallProgressItemViewModel(packageModificationRunner); + var vm = new PackageInstallProgressItemViewModel( + packageModificationRunner, + packageModificationRunner.HideCloseButton + ); ProgressItems.Add(vm); - if (packageModificationRunner.ShowDialogOnStart) - { - await vm.ShowProgressDialog(); - } + + return packageModificationRunner.ShowDialogOnStart + ? vm.ShowProgressDialog() + : Task.CompletedTask; } private void ShowFailedNotification(string title, string message) diff --git a/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs index 10ee1194..b624f1e6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs @@ -16,10 +16,12 @@ namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(RefreshBadge))] [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] +[ManagedService] +[Transient] public partial class RefreshBadgeViewModel : ViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - + public string WorkingToolTipText { get; set; } = "Loading..."; public string SuccessToolTipText { get; set; } = "Success"; public string InactiveToolTipText { get; set; } = ""; @@ -32,7 +34,7 @@ public partial class RefreshBadgeViewModel : ViewModelBase public IBrush SuccessColorBrush { get; set; } = ThemeColors.ThemeGreen; public IBrush InactiveColorBrush { get; set; } = ThemeColors.ThemeYellow; public IBrush FailColorBrush { get; set; } = ThemeColors.ThemeYellow; - + public Func>? RefreshFunc { get; set; } [ObservableProperty] @@ -41,7 +43,7 @@ public partial class RefreshBadgeViewModel : ViewModelBase [NotifyPropertyChangedFor(nameof(CurrentToolTip))] [NotifyPropertyChangedFor(nameof(Icon))] private ProgressState state; - + public bool IsWorking => State == ProgressState.Working; /*public ControlAppearance Appearance => State switch @@ -51,36 +53,40 @@ public partial class RefreshBadgeViewModel : ViewModelBase ProgressState.Failed => ControlAppearance.Danger, _ => ControlAppearance.Secondary };*/ - - public IBrush ColorBrush => State switch - { - ProgressState.Success => SuccessColorBrush, - ProgressState.Inactive => InactiveColorBrush, - ProgressState.Failed => FailColorBrush, - _ => Brushes.Gray - }; - public string CurrentToolTip => State switch - { - ProgressState.Working => WorkingToolTipText, - ProgressState.Success => SuccessToolTipText, - ProgressState.Inactive => InactiveToolTipText, - ProgressState.Failed => FailToolTipText, - _ => "" - }; - - public Symbol Icon => State switch - { - ProgressState.Success => SuccessIcon, - ProgressState.Failed => FailIcon, - _ => InactiveIcon - }; + public IBrush ColorBrush => + State switch + { + ProgressState.Success => SuccessColorBrush, + ProgressState.Inactive => InactiveColorBrush, + ProgressState.Failed => FailColorBrush, + _ => Brushes.Gray + }; + + public string CurrentToolTip => + State switch + { + ProgressState.Working => WorkingToolTipText, + ProgressState.Success => SuccessToolTipText, + ProgressState.Inactive => InactiveToolTipText, + ProgressState.Failed => FailToolTipText, + _ => "" + }; + + public Symbol Icon => + State switch + { + ProgressState.Success => SuccessIcon, + ProgressState.Failed => FailIcon, + _ => InactiveIcon + }; [RelayCommand] private async Task Refresh() { Logger.Info("Running refresh command..."); - if (RefreshFunc == null) return; + if (RefreshFunc == null) + return; State = ProgressState.Working; try diff --git a/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs index cec925fb..80d5c7ba 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs @@ -1,12 +1,9 @@ -using CommunityToolkit.Mvvm.ComponentModel; -using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.Views.Settings; using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.ViewModels.Settings; [View(typeof(InferenceSettingsPage))] -public partial class InferenceSettingsViewModel : ViewModelBase -{ - -} +[Singleton] +public partial class InferenceSettingsViewModel : ViewModelBase { } diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs index eee148fa..74dc145b 100644 --- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs @@ -3,15 +3,16 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Collections.ObjectModel; using System.ComponentModel; +using System.ComponentModel.DataAnnotations; using System.Diagnostics; using System.Globalization; using System.IO; using System.Linq; +using System.Reactive.Linq; using System.Reflection; using System.Text; using System.Text.Json; using System.Threading.Tasks; -using AsyncAwaitBestPractices; using Avalonia; using Avalonia.Controls.Notifications; using Avalonia.Controls.Primitives; @@ -21,6 +22,7 @@ using Avalonia.Styling; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using DynamicData.Binding; using FluentAvalonia.UI.Controls; using NLog; using SkiaSharp; @@ -29,6 +31,7 @@ using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; @@ -49,6 +52,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource; namespace StabilityMatrix.Avalonia.ViewModels; [View(typeof(SettingsPage))] +[Singleton] public partial class SettingsViewModel : PageViewModelBase { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -107,6 +111,25 @@ public partial class SettingsViewModel : PageViewModelBase [ObservableProperty] private bool isCompletionRemoveUnderscoresEnabled = true; + [ObservableProperty] + [CustomValidation(typeof(SettingsViewModel), nameof(ValidateOutputImageFileNameFormat))] + private string? outputImageFileNameFormat; + + [ObservableProperty] + private string? outputImageFileNameFormatSample; + + public IEnumerable OutputImageFileNameFormatVars => + FileNameFormatProvider + .GetSample() + .Substitutions.Select( + kv => + new FileNameFormatVar + { + Variable = $"{{{kv.Key}}}", + Example = kv.Value.Invoke() + } + ); + [ObservableProperty] private bool isImageViewerPixelGridEnabled = true; @@ -201,6 +224,39 @@ public partial class SettingsViewModel : PageViewModelBase true ); + this.WhenPropertyChanged(vm => vm.OutputImageFileNameFormat) + .Throttle(TimeSpan.FromMilliseconds(50)) + .Subscribe(formatProperty => + { + var provider = FileNameFormatProvider.GetSample(); + var template = formatProperty.Value ?? string.Empty; + + if ( + !string.IsNullOrEmpty(template) + && provider.Validate(template) == ValidationResult.Success + ) + { + var format = FileNameFormat.Parse(template, provider); + OutputImageFileNameFormatSample = format.GetFileName() + ".png"; + } + else + { + // Use default format if empty + var defaultFormat = FileNameFormat.Parse( + FileNameFormat.DefaultTemplate, + provider + ); + OutputImageFileNameFormatSample = defaultFormat.GetFileName() + ".png"; + } + }); + + settingsManager.RelayPropertyFor( + this, + vm => vm.OutputImageFileNameFormat, + settings => settings.InferenceOutputImageFileNameFormat, + true + ); + settingsManager.RelayPropertyFor( this, vm => vm.IsImageViewerPixelGridEnabled, @@ -225,6 +281,14 @@ public partial class SettingsViewModel : PageViewModelBase UpdateAvailableTagCompletionCsvs(); } + public static ValidationResult ValidateOutputImageFileNameFormat( + string? format, + ValidationContext context + ) + { + return FileNameFormatProvider.GetSample().Validate(format ?? string.Empty); + } + partial void OnSelectedThemeChanged(string? value) { // In case design / tests diff --git a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml index 7ad45f0c..e14edf4a 100644 --- a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml +++ b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml @@ -21,7 +21,7 @@ @@ -66,6 +66,7 @@ - + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml.cs new file mode 100644 index 00000000..197a94ad --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml.cs @@ -0,0 +1,13 @@ +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Core.Attributes; + +namespace StabilityMatrix.Avalonia.Views.Dialogs; + +[Transient] +public partial class PythonPackagesDialog : UserControlBase +{ + public PythonPackagesDialog() + { + InitializeComponent(); + } +} diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs index 3e874a83..f9eb3bf1 100644 --- a/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs @@ -1,8 +1,10 @@ using Avalonia.Markup.Xaml; using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Views.Dialogs; +[Transient] public partial class SelectDataDirectoryDialog : UserControlBase { public SelectDataDirectoryDialog() diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs index 33d2b218..94ac9c31 100644 --- a/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs @@ -1,8 +1,10 @@ using Avalonia.Controls; using Avalonia.Markup.Xaml; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Views.Dialogs; +[Transient] public partial class UpdateDialog : UserControl { public UpdateDialog() @@ -14,4 +16,4 @@ public partial class UpdateDialog : UserControl { AvaloniaXamlLoader.Load(this); } -} \ No newline at end of file +} diff --git a/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs b/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs index f02ea254..e3a9cf7f 100644 --- a/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs @@ -1,13 +1,13 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using Avalonia; +using System.Diagnostics.CodeAnalysis; using Avalonia.Interactivity; using Avalonia.Markup.Xaml; using FluentAvalonia.UI.Controls; using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Views; +[Singleton] public partial class FirstLaunchSetupWindow : AppWindowBase { public ContentDialogResult Result { get; private set; } diff --git a/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs b/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs index eaa6a883..0e13e9d4 100644 --- a/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs @@ -1,7 +1,9 @@ using StabilityMatrix.Avalonia.Controls.Dock; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Views.Inference; +[Transient] public partial class InferenceImageUpscaleView : DockUserControlBase { public InferenceImageUpscaleView() diff --git a/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs b/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs index 11784fe4..d9116275 100644 --- a/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs @@ -1,7 +1,9 @@ using StabilityMatrix.Avalonia.Controls.Dock; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Views.Inference; +[Transient] public partial class InferenceTextToImageView : DockUserControlBase { public InferenceTextToImageView() diff --git a/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs b/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs index 7dd4e5ed..05e4eb76 100644 --- a/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs @@ -7,9 +7,11 @@ using FluentAvalonia.UI.Controls; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.ViewModels; +using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Views; +[Singleton] public partial class InferencePage : UserControlBase { private Button? _addButton; diff --git a/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml b/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml index e77a3cfa..cd81a8a8 100644 --- a/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml +++ b/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml @@ -35,7 +35,7 @@ Grid.Row="0" Grid.Column="0"> + + + + + + public async Task MoveToAsync(FilePath destinationFile) { - await Task.Run(() => Info.MoveTo(destinationFile.FullPath)).ConfigureAwait(false); + await Task.Run(() => + { + var path = destinationFile.FullPath; + if (destinationFile.Exists) + { + var num = Random.Shared.NextInt64(0, 10000); + path = path.Replace( + destinationFile.NameWithoutExtension, + $"{destinationFile.NameWithoutExtension}_{num}" + ); + } + + Info.MoveTo(path); + }) + .ConfigureAwait(false); // Return the new path return destinationFile; } diff --git a/StabilityMatrix.Core/Models/GenerationParameters.cs b/StabilityMatrix.Core/Models/GenerationParameters.cs index fab9f793..c0b5aa9b 100644 --- a/StabilityMatrix.Core/Models/GenerationParameters.cs +++ b/StabilityMatrix.Core/Models/GenerationParameters.cs @@ -1,4 +1,6 @@ -using System.Diagnostics.CodeAnalysis; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Text.RegularExpressions; using StabilityMatrix.Core.Models.Api.Comfy; @@ -30,12 +32,26 @@ public partial record GenerationParameters return false; } + try + { + generationParameters = Parse(text); + } + catch (Exception) + { + generationParameters = null; + return false; + } + + return true; + } + + public static GenerationParameters Parse(string text) + { var lines = text.Split('\n'); if (lines.LastOrDefault() is not { } lastLine) { - generationParameters = null; - return false; + throw new ValidationException("Fields line not found"); } if (lastLine.StartsWith("Steps:") != true) @@ -45,47 +61,84 @@ public partial record GenerationParameters if (lastLine.StartsWith("Steps:") != true) { - generationParameters = null; - return false; + throw new ValidationException("Unable to locate starting marker of last line"); } } // Join lines before last line, split at 'Negative prompt: ' - var joinedLines = string.Join("\n", lines[..^1]); + var joinedLines = string.Join("\n", lines[..^1]).Trim(); - var splitFirstPart = joinedLines.Split("Negative prompt: "); - if (splitFirstPart.Length != 2) - { - generationParameters = null; - return false; - } + var splitFirstPart = joinedLines.Split("Negative prompt: ", 2); - var positivePrompt = splitFirstPart[0]; - var negativePrompt = splitFirstPart[1]; + var positivePrompt = splitFirstPart.ElementAtOrDefault(0)?.Trim(); + var negativePrompt = splitFirstPart.ElementAtOrDefault(1)?.Trim(); // Parse last line - var match = ParseLastLineRegex().Match(lastLine); - if (!match.Success) - { - generationParameters = null; - return false; - } + var lineFields = ParseLine(lastLine); - generationParameters = new GenerationParameters + var generationParameters = new GenerationParameters { PositivePrompt = positivePrompt, NegativePrompt = negativePrompt, - Steps = int.Parse(match.Groups["Steps"].Value), - Sampler = match.Groups["Sampler"].Value, - CfgScale = double.Parse(match.Groups["CfgScale"].Value), - Seed = ulong.Parse(match.Groups["Seed"].Value), - Height = int.Parse(match.Groups["Height"].Value), - Width = int.Parse(match.Groups["Width"].Value), - ModelHash = match.Groups["ModelHash"].Value, - ModelName = match.Groups["ModelName"].Value, + Steps = int.Parse(lineFields.GetValueOrDefault("Steps", "0")), + Sampler = lineFields.GetValueOrDefault("Sampler"), + CfgScale = double.Parse(lineFields.GetValueOrDefault("CFG scale", "0")), + Seed = ulong.Parse(lineFields.GetValueOrDefault("Seed", "0")), + ModelHash = lineFields.GetValueOrDefault("Model hash"), + ModelName = lineFields.GetValueOrDefault("Model"), }; - return true; + if (lineFields.GetValueOrDefault("Size") is { } size) + { + var split = size.Split('x', 2); + if (split.Length == 2) + { + generationParameters = generationParameters with + { + Width = int.Parse(split[0]), + Height = int.Parse(split[1]) + }; + } + } + + return generationParameters; + } + + /// + /// Parse A1111 metadata fields in a single line where + /// fields are separated by commas and key-value pairs are separated by colons. + /// i.e. "key1: value1, key2: value2" + /// + internal static Dictionary ParseLine(string fields) + { + var dict = new Dictionary(); + + // Values main contain commas or colons + foreach (var match in ParametersFieldsRegex().Matches(fields).Cast()) + { + if (!match.Success) + continue; + + var key = match.Groups[1].Value.Trim(); + var value = UnquoteValue(match.Groups[2].Value.Trim()); + + dict.Add(key, value); + } + + return dict; + } + + /// + /// Unquotes a quoted value field if required + /// + private static string UnquoteValue(string quotedField) + { + if (!(quotedField.StartsWith('"') && quotedField.EndsWith('"'))) + { + return quotedField; + } + + return JsonNode.Parse(quotedField)?.GetValue() ?? ""; } /// @@ -126,9 +179,26 @@ public partial record GenerationParameters return (sampler, scheduler); } - // Example: Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1 - [GeneratedRegex( - """^Steps: (?\d+), Sampler: (?.+?), CFG scale: (?\d+(\.\d+)?), Seed: (?\d+), Size: (?\d+)x(?\d+), Model hash: (?.+?), Model: (?.+)$""" - )] - private static partial Regex ParseLastLineRegex(); + /// + /// Return a sample parameters for UI preview + /// + public static GenerationParameters GetSample() + { + return new GenerationParameters + { + PositivePrompt = "(cat:1.2), by artist, detailed, [shaded]", + NegativePrompt = "blurry, jpg artifacts", + Steps = 30, + CfgScale = 7, + Width = 640, + Height = 896, + Seed = 124825529, + ModelName = "ExampleMix7", + ModelHash = "b899d188a1ac7356bfb9399b2277d5b21712aa360f8f9514fba6fcce021baff7", + Sampler = "DPM++ 2M Karras" + }; + } + + [GeneratedRegex("""\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)""")] + private static partial Regex ParametersFieldsRegex(); } diff --git a/StabilityMatrix.Core/Models/HybridModelFile.cs b/StabilityMatrix.Core/Models/HybridModelFile.cs index 2f178fd0..60fd8d4f 100644 --- a/StabilityMatrix.Core/Models/HybridModelFile.cs +++ b/StabilityMatrix.Core/Models/HybridModelFile.cs @@ -30,7 +30,10 @@ public record HybridModelFile public bool IsRemote => RemoteName != null; [JsonIgnore] - public string FileName => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder; + public string RelativePath => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder; + + [JsonIgnore] + public string FileName => Path.GetFileName(RelativePath); [JsonIgnore] public string ShortDisplayName @@ -47,7 +50,7 @@ public record HybridModelFile return "Default"; } - return Path.GetFileNameWithoutExtension(FileName); + return Path.GetFileNameWithoutExtension(RelativePath); } } @@ -63,7 +66,7 @@ public record HybridModelFile public string GetId() { - return $"{FileName};{IsNone};{IsDefault}"; + return $"{RelativePath};{IsNone};{IsDefault}"; } private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer @@ -79,14 +82,14 @@ public record HybridModelFile if (x.GetType() != y.GetType()) return false; - return Equals(x.FileName, y.FileName) + return Equals(x.RelativePath, y.RelativePath) && x.IsNone == y.IsNone && x.IsDefault == y.IsDefault; } public int GetHashCode(HybridModelFile obj) { - return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.FileName); + return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.RelativePath); } } diff --git a/StabilityMatrix.Core/Models/InstalledPackage.cs b/StabilityMatrix.Core/Models/InstalledPackage.cs index 4695bb12..557d07ab 100644 --- a/StabilityMatrix.Core/Models/InstalledPackage.cs +++ b/StabilityMatrix.Core/Models/InstalledPackage.cs @@ -50,6 +50,7 @@ public class InstalledPackage : IJsonOnDeserialized public bool UpdateAvailable { get; set; } public TorchVersion? PreferredTorchVersion { get; set; } public SharedFolderMethod? PreferredSharedFolderMethod { get; set; } + public bool UseSharedOutputFolder { get; set; } /// /// Get the launch args host option value. diff --git a/StabilityMatrix.Core/Models/PackageDifficulty.cs b/StabilityMatrix.Core/Models/PackageDifficulty.cs new file mode 100644 index 00000000..51eacc46 --- /dev/null +++ b/StabilityMatrix.Core/Models/PackageDifficulty.cs @@ -0,0 +1,12 @@ +namespace StabilityMatrix.Core.Models; + +public enum PackageDifficulty +{ + Recommended = 0, + Simple = 1, + Advanced = 2, + Expert = 3, + Nightmare = 4, + UltraNightmare = 5, + Impossible = 999 +} diff --git a/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs b/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs index 6c6f1e9c..6037b106 100644 --- a/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs +++ b/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs @@ -12,6 +12,7 @@ public interface IPackageModificationRunner List ConsoleOutput { get; } Guid Id { get; } bool ShowDialogOnStart { get; init; } + bool HideCloseButton { get; init; } string? ModificationCompleteMessage { get; init; } bool Failed { get; set; } } diff --git a/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs b/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs index b307e161..372fdb5f 100644 --- a/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs +++ b/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs @@ -8,12 +8,19 @@ public class InstallPackageStep : IPackageStep { private readonly BasePackage package; private readonly TorchVersion torchVersion; + private readonly DownloadPackageVersionOptions versionOptions; private readonly string installPath; - public InstallPackageStep(BasePackage package, TorchVersion torchVersion, string installPath) + public InstallPackageStep( + BasePackage package, + TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, + string installPath + ) { this.package = package; this.torchVersion = torchVersion; + this.versionOptions = versionOptions; this.installPath = installPath; } @@ -25,9 +32,9 @@ public class InstallPackageStep : IPackageStep } await package - .InstallPackage(installPath, torchVersion, progress, OnConsoleOutput) + .InstallPackage(installPath, torchVersion, versionOptions, progress, OnConsoleOutput) .ConfigureAwait(false); } - public string ProgressTitle => "Installing package..."; + public string ProgressTitle => $"Installing {package.DisplayName}..."; } diff --git a/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs b/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs index c546671e..0adbe468 100644 --- a/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs +++ b/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs @@ -54,6 +54,7 @@ public class PackageModificationRunner : IPackageModificationRunner IsRunning = false; } + public bool HideCloseButton { get; init; } public string? ModificationCompleteMessage { get; init; } public bool ShowDialogOnStart { get; init; } diff --git a/StabilityMatrix.Core/Models/PackageModification/PipStep.cs b/StabilityMatrix.Core/Models/PackageModification/PipStep.cs new file mode 100644 index 00000000..6d5d0cfa --- /dev/null +++ b/StabilityMatrix.Core/Models/PackageModification/PipStep.cs @@ -0,0 +1,45 @@ +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Models.PackageModification; + +public class PipStep : IPackageStep +{ + public required ProcessArgs Args { get; init; } + public required DirectoryPath VenvDirectory { get; init; } + + public DirectoryPath? WorkingDirectory { get; init; } + + public IReadOnlyDictionary? EnvironmentVariables { get; init; } + + /// + public string ProgressTitle => + Args switch + { + _ when Args.Contains("install") => "Installing Pip Packages", + _ when Args.Contains("uninstall") => "Uninstalling Pip Packages", + _ when Args.Contains("-U") || Args.Contains("--upgrade") => "Updating Pip Packages", + _ => "Running Pip" + }; + + /// + public async Task ExecuteAsync(IProgress? progress = null) + { + await using var venvRunner = new PyVenvRunner(VenvDirectory) + { + WorkingDirectory = WorkingDirectory, + EnvironmentVariables = EnvironmentVariables + }; + + var args = new List { "-m", "pip" }; + args.AddRange(Args.ToArray()); + + venvRunner.RunDetached(args.ToArray(), progress.AsProcessOutputHandler()); + + await ProcessRunner.WaitForExitConditionAsync(venvRunner.Process).ConfigureAwait(false); + } +} diff --git a/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs b/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs index e7fbfa97..be996cfe 100644 --- a/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs +++ b/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs @@ -9,16 +9,19 @@ public class UpdatePackageStep : IPackageStep { private readonly ISettingsManager settingsManager; private readonly InstalledPackage installedPackage; + private readonly DownloadPackageVersionOptions versionOptions; private readonly BasePackage basePackage; public UpdatePackageStep( ISettingsManager settingsManager, InstalledPackage installedPackage, + DownloadPackageVersionOptions versionOptions, BasePackage basePackage ) { this.settingsManager = settingsManager; this.installedPackage = installedPackage; + this.versionOptions = versionOptions; this.basePackage = basePackage; } @@ -33,7 +36,13 @@ public class UpdatePackageStep : IPackageStep } var updateResult = await basePackage - .Update(installedPackage, torchVersion, progress, onConsoleOutput: OnConsoleOutput) + .Update( + installedPackage, + torchVersion, + versionOptions, + progress, + onConsoleOutput: OnConsoleOutput + ) .ConfigureAwait(false); settingsManager.UpdatePackageVersionNumber(installedPackage.Id, updateResult); diff --git a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs index 7f729d91..50defb6c 100644 --- a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs +++ b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs @@ -2,6 +2,7 @@ using System.Text.Json.Nodes; using System.Text.RegularExpressions; using NLog; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.FileInterfaces; @@ -12,6 +13,7 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class A3WebUI : BaseGitPackage { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -29,6 +31,8 @@ public class A3WebUI : BaseGitPackage new("https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png"); public string RelativeArgsDefinitionScriptPath => "modules.cmd_args"; + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Recommended; + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; public A3WebUI( @@ -61,6 +65,17 @@ public class A3WebUI : BaseGitPackage [SharedFolderType.AfterDetailer] = new[] { "models/adetailer" } }; + public override Dictionary>? SharedOutputFolders => + new() + { + [SharedOutputType.Extras] = new[] { "outputs/extras-images" }, + [SharedOutputType.Saved] = new[] { "log/images" }, + [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" }, + [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" }, + [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" }, + [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" } + }; + [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] public override List LaunchOptions => new() @@ -157,7 +172,7 @@ public class A3WebUI : BaseGitPackage new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm }; + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm }; public override async Task GetLatestVersion() { @@ -165,15 +180,16 @@ public class A3WebUI : BaseGitPackage return release.TagName!; } + public override string OutputFolderName => "outputs"; + public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) { - await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); - progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); // Setup venv await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); @@ -191,15 +207,14 @@ public class A3WebUI : BaseGitPackage case TorchVersion.Rocm: await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); break; - case TorchVersion.DirectMl: - await InstallDirectMlTorch(venvRunner, progress, onConsoleOutput) - .ConfigureAwait(false); - break; default: throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null); } - await venvRunner.PipInstall("httpx==0.24.1", onConsoleOutput); + if (versionOptions.VersionTag?.Contains("1.6.0") ?? false) + { + await venvRunner.PipInstall("httpx==0.24.1", onConsoleOutput); + } // Install requirements file progress?.Report( @@ -212,10 +227,6 @@ public class A3WebUI : BaseGitPackage .PipInstallFromRequirements(requirements, onConsoleOutput, excludes: "torch") .ConfigureAwait(false); - progress?.Report( - new ProgressReport(1f, "Installing Package Requirements", isIndeterminate: false) - ); - progress?.Report(new ProgressReport(-1f, "Updating configuration", isIndeterminate: true)); // Create and add {"show_progress_type": "TAESD"} to config.json @@ -273,7 +284,13 @@ public class A3WebUI : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.1.1"), + onConsoleOutput + ) .ConfigureAwait(false); } } diff --git a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs index dab562cc..b5984cc8 100644 --- a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs +++ b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs @@ -170,31 +170,45 @@ public abstract class BaseGitPackage : BasePackage IProgress? progress = null ) { - var downloadUrl = GetDownloadUrl(versionOptions); - - if (!Directory.Exists(DownloadLocation.Replace($"{Name}.zip", ""))) + if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag)) + { + await PrerequisiteHelper + .RunGit( + null, + null, + "clone", + "--branch", + versionOptions.VersionTag, + GithubUrl, + $"\"{installLocation}\"" + ) + .ConfigureAwait(false); + } + else if (!string.IsNullOrWhiteSpace(versionOptions.BranchName)) { - Directory.CreateDirectory(DownloadLocation.Replace($"{Name}.zip", "")); + await PrerequisiteHelper + .RunGit( + null, + null, + "clone", + "--branch", + versionOptions.BranchName, + GithubUrl, + $"\"{installLocation}\"" + ) + .ConfigureAwait(false); } - await DownloadService - .DownloadToFileAsync(downloadUrl, DownloadLocation, progress: progress) - .ConfigureAwait(false); + if (!versionOptions.IsLatest && !string.IsNullOrWhiteSpace(versionOptions.CommitHash)) + { + await PrerequisiteHelper + .RunGit(installLocation, null, "checkout", versionOptions.CommitHash) + .ConfigureAwait(false); + } progress?.Report(new ProgressReport(100, message: "Download Complete")); } - public override async Task InstallPackage( - string installLocation, - TorchVersion torchVersion, - IProgress? progress = null, - Action? onConsoleOutput = null - ) - { - await UnzipPackage(installLocation, progress).ConfigureAwait(false); - File.Delete(DownloadLocation); - } - protected Task UnzipPackage(string installLocation, IProgress? progress = null) { using var zip = ZipFile.OpenRead(DownloadLocation); @@ -279,6 +293,7 @@ public abstract class BaseGitPackage : BasePackage public override async Task Update( InstalledPackage installedPackage, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, bool includePrerelease = false, Action? onConsoleOutput = null @@ -287,47 +302,146 @@ public abstract class BaseGitPackage : BasePackage if (installedPackage.Version == null) throw new NullReferenceException("Version is null"); - if (installedPackage.Version.IsReleaseMode) + if (!Directory.Exists(Path.Combine(installedPackage.FullPath!, ".git"))) { - var releases = await GetAllReleases().ConfigureAwait(false); - var latestRelease = releases.First(x => includePrerelease || !x.Prerelease); + Logger.Info("not a git repo, initializing..."); + progress?.Report( + new ProgressReport(-1f, "Initializing git repo", isIndeterminate: true) + ); + await PrerequisiteHelper + .RunGit(installedPackage.FullPath!, onConsoleOutput, "init") + .ConfigureAwait(false); + await PrerequisiteHelper + .RunGit( + installedPackage.FullPath!, + onConsoleOutput, + "remote", + "add", + "origin", + GithubUrl + ) + .ConfigureAwait(false); + } - await DownloadPackage( - installedPackage.FullPath, - new DownloadPackageVersionOptions { VersionTag = latestRelease.TagName }, - progress + if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag)) + { + progress?.Report(new ProgressReport(-1f, "Fetching tags...", isIndeterminate: true)); + await PrerequisiteHelper + .RunGit(installedPackage.FullPath!, onConsoleOutput, "fetch", "--tags") + .ConfigureAwait(false); + + progress?.Report( + new ProgressReport( + -1f, + $"Checking out {versionOptions.VersionTag}", + isIndeterminate: true + ) + ); + await PrerequisiteHelper + .RunGit( + installedPackage.FullPath!, + onConsoleOutput, + "checkout", + versionOptions.VersionTag, + "--force" ) .ConfigureAwait(false); - await InstallPackage(installedPackage.FullPath, torchVersion, progress, onConsoleOutput) + await InstallPackage( + installedPackage.FullPath!, + torchVersion, + new DownloadPackageVersionOptions + { + VersionTag = versionOptions.VersionTag, + IsLatest = versionOptions.IsLatest + }, + progress, + onConsoleOutput + ) .ConfigureAwait(false); - return new InstalledPackageVersion { InstalledReleaseVersion = latestRelease.TagName }; + return new InstalledPackageVersion + { + InstalledReleaseVersion = versionOptions.VersionTag + }; } - // Commit mode - var allCommits = await GetAllCommits(installedPackage.Version.InstalledBranch) + // fetch + progress?.Report(new ProgressReport(-1f, "Fetching data...", isIndeterminate: true)); + await PrerequisiteHelper + .RunGit(installedPackage.FullPath!, onConsoleOutput, "fetch") .ConfigureAwait(false); - var latestCommit = allCommits?.First(); - if (latestCommit is null || string.IsNullOrEmpty(latestCommit.Sha)) + if (versionOptions.IsLatest) { - throw new Exception("No commits found for branch"); + // checkout + progress?.Report( + new ProgressReport( + -1f, + $"Checking out {installedPackage.Version.InstalledBranch}...", + isIndeterminate: true + ) + ); + await PrerequisiteHelper + .RunGit( + installedPackage.FullPath!, + onConsoleOutput, + "checkout", + versionOptions.BranchName, + "--force" + ) + .ConfigureAwait(false); + + // pull + progress?.Report(new ProgressReport(-1f, "Pulling changes...", isIndeterminate: true)); + await PrerequisiteHelper + .RunGit( + installedPackage.FullPath!, + onConsoleOutput, + "pull", + "origin", + installedPackage.Version.InstalledBranch + ) + .ConfigureAwait(false); + } + else + { + // checkout + progress?.Report( + new ProgressReport( + -1f, + $"Checking out {installedPackage.Version.InstalledBranch}...", + isIndeterminate: true + ) + ); + await PrerequisiteHelper + .RunGit( + installedPackage.FullPath!, + onConsoleOutput, + "checkout", + versionOptions.CommitHash, + "--force" + ) + .ConfigureAwait(false); } - await DownloadPackage( + await InstallPackage( installedPackage.FullPath, - new DownloadPackageVersionOptions { CommitHash = latestCommit.Sha }, - progress + torchVersion, + new DownloadPackageVersionOptions + { + CommitHash = versionOptions.CommitHash, + IsLatest = versionOptions.IsLatest + }, + progress, + onConsoleOutput ) .ConfigureAwait(false); - await InstallPackage(installedPackage.FullPath, torchVersion, progress, onConsoleOutput) - .ConfigureAwait(false); return new InstalledPackageVersion { - InstalledBranch = installedPackage.Version.InstalledBranch, - InstalledCommitSha = latestCommit.Sha + InstalledBranch = versionOptions.BranchName, + InstalledCommitSha = versionOptions.CommitHash }; } @@ -372,7 +486,37 @@ public abstract class BaseGitPackage : BasePackage { if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink) { - StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(this, installDirectory); + StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage( + SharedFolders, + installDirectory + ); + } + return Task.CompletedTask; + } + + public override Task SetupOutputFolderLinks(DirectoryPath installDirectory) + { + if (SharedOutputFolders is { } sharedOutputFolders) + { + return StabilityMatrix.Core.Helper.SharedFolders.UpdateLinksForPackage( + sharedOutputFolders, + SettingsManager.ImagesDirectory, + installDirectory, + recursiveDelete: true + ); + } + + return Task.CompletedTask; + } + + public override Task RemoveOutputFolderLinks(DirectoryPath installDirectory) + { + if (SharedOutputFolders is { } sharedOutputFolders) + { + StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage( + sharedOutputFolders, + installDirectory + ); } return Task.CompletedTask; } diff --git a/StabilityMatrix.Core/Models/Packages/BasePackage.cs b/StabilityMatrix.Core/Models/Packages/BasePackage.cs index d9fd2289..6165c697 100644 --- a/StabilityMatrix.Core/Models/Packages/BasePackage.cs +++ b/StabilityMatrix.Core/Models/Packages/BasePackage.cs @@ -38,6 +38,14 @@ public abstract class BasePackage public virtual bool IsInferenceCompatible => false; + public abstract string OutputFolderName { get; } + + public abstract IEnumerable AvailableTorchVersions { get; } + + public virtual bool IsCompatible => GetRecommendedTorchVersion() != TorchVersion.Cpu; + + public abstract PackageDifficulty InstallerSortOrder { get; } + public abstract Task DownloadPackage( string installLocation, DownloadPackageVersionOptions versionOptions, @@ -47,6 +55,7 @@ public abstract class BasePackage public abstract Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ); @@ -63,6 +72,7 @@ public abstract class BasePackage public abstract Task Update( InstalledPackage installedPackage, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, bool includePrerelease = false, Action? onConsoleOutput = null @@ -93,7 +103,8 @@ public abstract class BasePackage SharedFolderMethod sharedFolderMethod ); - public abstract IEnumerable AvailableTorchVersions { get; } + public abstract Task SetupOutputFolderLinks(DirectoryPath installDirectory); + public abstract Task RemoveOutputFolderLinks(DirectoryPath installDirectory); public virtual TorchVersion GetRecommendedTorchVersion() { @@ -142,7 +153,11 @@ public abstract class BasePackage /// The shared folders that this package supports. /// Mapping of to the relative paths from the package root. /// - public virtual Dictionary>? SharedFolders { get; } + public abstract Dictionary>? SharedFolders { get; } + public abstract Dictionary< + SharedOutputType, + IReadOnlyList + >? SharedOutputFolders { get; } public abstract Task GetLatestVersion(); public abstract Task GetAllVersionOptions(); @@ -176,9 +191,15 @@ public abstract class BasePackage ); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision("==0.15.2") + .WithXFormers("==0.0.20") + .WithTorchExtraIndex("cu118"), + onConsoleOutput + ) .ConfigureAwait(false); - await venvRunner.PipInstall("xformers==0.0.20", onConsoleOutput).ConfigureAwait(false); } protected Task InstallDirectMlTorch( @@ -191,7 +212,7 @@ public abstract class BasePackage new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true) ); - return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput); + return venvRunner.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput); } protected Task InstallCpuTorch( @@ -204,6 +225,9 @@ public abstract class BasePackage new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true) ); - return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, onConsoleOutput); + return venvRunner.PipInstall( + new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(), + onConsoleOutput + ); } } diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index af3881fb..c81055a3 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -1,6 +1,7 @@ using System.Diagnostics; using System.Text.RegularExpressions; using NLog; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.FileInterfaces; @@ -14,6 +15,7 @@ using YamlDotNet.Serialization.NamingConventions; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class ComfyUI : BaseGitPackage { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -30,6 +32,9 @@ public class ComfyUI : BaseGitPackage new("https://github.com/comfyanonymous/ComfyUI/raw/master/comfyui_screenshot.png"); public override bool ShouldIgnoreReleases => true; public override bool IsInferenceCompatible => true; + public override string OutputFolderName => "output"; + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced; + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Configuration; @@ -58,6 +63,9 @@ public class ComfyUI : BaseGitPackage [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" }, }; + public override Dictionary>? SharedOutputFolders => + new() { [SharedOutputType.Text2Img] = new[] { "output" } }; + public override List LaunchOptions => new List { @@ -141,17 +149,23 @@ public class ComfyUI : BaseGitPackage public override Task GetLatestVersion() => Task.FromResult("master"); public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm }; + new[] + { + TorchVersion.Cpu, + TorchVersion.Cuda, + TorchVersion.DirectMl, + TorchVersion.Rocm, + TorchVersion.Mps + }; public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) { - await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); - progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true)); // Setup venv await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); @@ -165,13 +179,36 @@ public class ComfyUI : BaseGitPackage await InstallCpuTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); break; case TorchVersion.Cuda: - await InstallCudaTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + await venvRunner + .PipInstall( + new PipInstallArgs() + .WithTorch("~=2.1.0") + .WithTorchVision() + .WithXFormers("==0.0.22.post4") + .AddArg("--upgrade") + .WithTorchExtraIndex("cu121"), + onConsoleOutput + ) + .ConfigureAwait(false); + break; + case TorchVersion.DirectMl: + await venvRunner + .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput) + .ConfigureAwait(false); break; case TorchVersion.Rocm: await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); break; - case TorchVersion.DirectMl: - await InstallDirectMlTorch(venvRunner, progress, onConsoleOutput) + case TorchVersion.Mps: + await venvRunner + .PipInstall( + new PipInstallArgs() + .AddArg("--pre") + .WithTorch() + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu"), + onConsoleOutput + ) .ConfigureAwait(false); break; default: @@ -441,7 +478,13 @@ public class ComfyUI : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.6"), + onConsoleOutput + ) .ConfigureAwait(false); } diff --git a/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs b/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs index ea1e34b5..54eba855 100644 --- a/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs +++ b/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs @@ -1,6 +1,7 @@ using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Services; @@ -30,6 +31,20 @@ public class DankDiffusion : BaseGitPackage public override Uri PreviewImageUri { get; } + public override string OutputFolderName { get; } + public override PackageDifficulty InstallerSortOrder { get; } + + public override Task InstallPackage( + string installLocation, + TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, + IProgress? progress = null, + Action? onConsoleOutput = null + ) + { + throw new NotImplementedException(); + } + public override Task RunPackage( string installedPackagePath, string command, @@ -68,6 +83,12 @@ public class DankDiffusion : BaseGitPackage public override List LaunchOptions { get; } + public override Dictionary>? SharedFolders { get; } + public override Dictionary< + SharedOutputType, + IReadOnlyList + >? SharedOutputFolders { get; } + public override Task GetLatestVersion() { throw new NotImplementedException(); diff --git a/StabilityMatrix.Core/Models/Packages/FocusControlNet.cs b/StabilityMatrix.Core/Models/Packages/FocusControlNet.cs new file mode 100644 index 00000000..b86d0a44 --- /dev/null +++ b/StabilityMatrix.Core/Models/Packages/FocusControlNet.cs @@ -0,0 +1,37 @@ +using System.Diagnostics; +using System.Text.RegularExpressions; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Models.Packages; + +[Singleton(typeof(BasePackage))] +public class FocusControlNet : Fooocus +{ + public FocusControlNet( + IGithubApiCache githubApi, + ISettingsManager settingsManager, + IDownloadService downloadService, + IPrerequisiteHelper prerequisiteHelper + ) + : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } + + public override string Name => "Fooocus-ControlNet-SDXL"; + public override string DisplayName { get; set; } = "Fooocus-ControlNet"; + public override string Author => "fenneishi"; + public override string Blurb => + "Fooocus-ControlNet adds more control to the original Fooocus software."; + public override string LicenseType => "GPL-3.0"; + public override string LicenseUrl => + "https://github.com/fenneishi/Fooocus-ControlNet-SDXL/blob/main/LICENSE"; + public override string LaunchCommand => "launch.py"; + public override Uri PreviewImageUri => + new("https://github.com/fenneishi/Fooocus-ControlNet-SDXL/raw/main/asset/canny/snip.png"); + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced; +} diff --git a/StabilityMatrix.Core/Models/Packages/Fooocus.cs b/StabilityMatrix.Core/Models/Packages/Fooocus.cs index 424f3553..b4ad8830 100644 --- a/StabilityMatrix.Core/Models/Packages/Fooocus.cs +++ b/StabilityMatrix.Core/Models/Packages/Fooocus.cs @@ -1,14 +1,17 @@ using System.Diagnostics; using System.Text.RegularExpressions; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class Fooocus : BaseGitPackage { public Fooocus( @@ -38,6 +41,12 @@ public class Fooocus : BaseGitPackage public override List LaunchOptions => new() { + new LaunchOptionDefinition + { + Name = "Preset", + Type = LaunchOptionType.Bool, + Options = { "--preset anime", "--preset realistic" } + }, new LaunchOptionDefinition { Name = "Port", @@ -59,6 +68,49 @@ public class Fooocus : BaseGitPackage Description = "Set the listen interface", Options = { "--listen" } }, + new LaunchOptionDefinition + { + Name = "Output Directory", + Type = LaunchOptionType.String, + Description = "Override the output directory", + Options = { "--output-directory" } + }, + new() + { + Name = "VRAM", + Type = LaunchOptionType.Bool, + InitialValue = HardwareHelper + .IterGpuInfo() + .Select(gpu => gpu.MemoryLevel) + .Max() switch + { + Level.Low => "--lowvram", + Level.Medium => "--normalvram", + _ => null + }, + Options = { "--highvram", "--normalvram", "--lowvram", "--novram" } + }, + new LaunchOptionDefinition + { + Name = "Use DirectML", + Type = LaunchOptionType.Bool, + Description = "Use pytorch with DirectML support", + InitialValue = HardwareHelper.PreferDirectML(), + Options = { "--directml" } + }, + new LaunchOptionDefinition + { + Name = "Disable Xformers", + Type = LaunchOptionType.Bool, + InitialValue = !HardwareHelper.HasNvidiaGpu(), + Options = { "--disable-xformers" } + }, + new LaunchOptionDefinition + { + Name = "Auto-Launch", + Type = LaunchOptionType.Bool, + Options = { "--auto-launch" } + }, LaunchOptionDefinition.Extras }; @@ -83,48 +135,59 @@ public class Fooocus : BaseGitPackage [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" } }; + public override Dictionary>? SharedOutputFolders => + new() { [SharedOutputType.Text2Img] = new[] { "outputs" } }; + public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm }; + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm }; public override Task GetLatestVersion() => Task.FromResult("main"); public override bool ShouldIgnoreReleases => true; + public override string OutputFolderName => "outputs"; + + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Simple; + public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) { - await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); var venvRunner = await SetupVenv(installLocation, forceRecreate: true) .ConfigureAwait(false); progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true)); - var torchVersionStr = "cpu"; - - switch (torchVersion) + if (torchVersion == TorchVersion.DirectMl) { - case TorchVersion.Cuda: - torchVersionStr = "cu118"; - break; - case TorchVersion.Rocm: - torchVersionStr = "rocm5.4.2"; - break; - case TorchVersion.Cpu: - break; - default: - throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null); + await venvRunner + .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput) + .ConfigureAwait(false); + } + else + { + var extraIndex = torchVersion switch + { + TorchVersion.Cpu => "cpu", + TorchVersion.Cuda => "cu121", + TorchVersion.Rocm => "rocm5.6", + _ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null) + }; + + await venvRunner + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.1.0") + .WithTorchVision("==0.16.0") + .WithTorchExtraIndex(extraIndex), + onConsoleOutput + ) + .ConfigureAwait(false); } - - await venvRunner - .PipInstall( - $"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersionStr}", - onConsoleOutput - ) - .ConfigureAwait(false); var requirements = new FilePath(installLocation, "requirements_versions.txt"); await venvRunner diff --git a/StabilityMatrix.Core/Models/Packages/FooocusMre.cs b/StabilityMatrix.Core/Models/Packages/FooocusMre.cs index fb65ce28..2313d444 100644 --- a/StabilityMatrix.Core/Models/Packages/FooocusMre.cs +++ b/StabilityMatrix.Core/Models/Packages/FooocusMre.cs @@ -1,14 +1,17 @@ using System.Diagnostics; using System.Text.RegularExpressions; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class FooocusMre : BaseGitPackage { public FooocusMre( @@ -37,6 +40,11 @@ public class FooocusMre : BaseGitPackage "https://user-images.githubusercontent.com/130458190/265366059-ce430ea0-0995-4067-98dd-cef1d7dc1ab6.png" ); + public override string Disclaimer => + "This package may no longer receive updates from its author. It may be removed from Stability Matrix in the future."; + + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Impossible; + public override List LaunchOptions => new() { @@ -85,6 +93,9 @@ public class FooocusMre : BaseGitPackage [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" } }; + public override Dictionary>? SharedOutputFolders => + new() { [SharedOutputType.Text2Img] = new[] { "outputs" } }; + public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm }; @@ -94,41 +105,47 @@ public class FooocusMre : BaseGitPackage return release.TagName!; } + public override string OutputFolderName => "outputs"; + public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) { - await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); var venvRunner = await SetupVenv(installLocation, forceRecreate: true) .ConfigureAwait(false); progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true)); - var torchVersionStr = "cpu"; - - switch (torchVersion) + if (torchVersion == TorchVersion.DirectMl) { - case TorchVersion.Cuda: - torchVersionStr = "cu118"; - break; - case TorchVersion.Rocm: - torchVersionStr = "rocm5.4.2"; - break; - case TorchVersion.Cpu: - break; - default: - throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null); + await venvRunner + .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput) + .ConfigureAwait(false); + } + else + { + var extraIndex = torchVersion switch + { + TorchVersion.Cpu => "cpu", + TorchVersion.Cuda => "cu118", + TorchVersion.Rocm => "rocm5.4.2", + _ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null) + }; + + await venvRunner + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision("==0.15.2") + .WithTorchExtraIndex(extraIndex), + onConsoleOutput + ) + .ConfigureAwait(false); } - - await venvRunner - .PipInstall( - $"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersionStr}", - onConsoleOutput - ) - .ConfigureAwait(false); var requirements = new FilePath(installLocation, "requirements_versions.txt"); await venvRunner diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs index 0284143c..2d86d438 100644 --- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs +++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs @@ -1,5 +1,7 @@ -using System.Text.RegularExpressions; +using System.Globalization; +using System.Text.RegularExpressions; using NLog; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; @@ -11,6 +13,7 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class InvokeAI : BaseGitPackage { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -20,11 +23,11 @@ public class InvokeAI : BaseGitPackage public override string DisplayName { get; set; } = "InvokeAI"; public override string Author => "invoke-ai"; public override string LicenseType => "Apache-2.0"; - public override string LicenseUrl => "https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE"; public override string Blurb => "Professional Creative Tools for Stable Diffusion"; public override string LaunchCommand => "invokeai-web"; + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Nightmare; public override IReadOnlyList ExtraLaunchCommands => new[] @@ -43,8 +46,6 @@ public class InvokeAI : BaseGitPackage "https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/docs/assets/canvas_preview.png" ); - public override bool ShouldIgnoreReleases => true; - public override IEnumerable AvailableSharedFolderMethods => new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; @@ -60,15 +61,35 @@ public class InvokeAI : BaseGitPackage public override Dictionary> SharedFolders => new() { - [SharedFolderType.StableDiffusion] = new[] { RelativeRootPath + "/autoimport/main" }, - [SharedFolderType.Lora] = new[] { RelativeRootPath + "/autoimport/lora" }, + [SharedFolderType.StableDiffusion] = new[] + { + Path.Combine(RelativeRootPath, "autoimport", "main") + }, + [SharedFolderType.Lora] = new[] + { + Path.Combine(RelativeRootPath, "autoimport", "lora") + }, [SharedFolderType.TextualInversion] = new[] { - RelativeRootPath + "/autoimport/embedding" + Path.Combine(RelativeRootPath, "autoimport", "embedding") }, - [SharedFolderType.ControlNet] = new[] { RelativeRootPath + "/autoimport/controlnet" }, + [SharedFolderType.ControlNet] = new[] + { + Path.Combine(RelativeRootPath, "autoimport", "controlnet") + } }; + public override Dictionary>? SharedOutputFolders => + new() + { + [SharedOutputType.Text2Img] = new[] + { + Path.Combine("invokeai-root", "outputs", "images") + } + }; + + public override string OutputFolderName => Path.Combine("invokeai-root", "outputs", "images"); + // https://github.com/invoke-ai/InvokeAI/blob/main/docs/features/CONFIGURATION.md public override List LaunchOptions => new List @@ -123,7 +144,11 @@ public class InvokeAI : BaseGitPackage LaunchOptionDefinition.Extras }; - public override Task GetLatestVersion() => Task.FromResult("main"); + public override async Task GetLatestVersion() + { + var release = await GetLatestRelease().ConfigureAwait(false); + return release.TagName!; + } public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm, TorchVersion.Mps }; @@ -138,18 +163,10 @@ public class InvokeAI : BaseGitPackage return base.GetRecommendedTorchVersion(); } - public override Task DownloadPackage( - string installLocation, - DownloadPackageVersionOptions downloadOptions, - IProgress? progress = null - ) - { - return Task.CompletedTask; - } - public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) @@ -157,7 +174,10 @@ public class InvokeAI : BaseGitPackage // Setup venv progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); - await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); + var venvPath = Path.Combine(installLocation, "venv"); + var exists = Directory.Exists(venvPath); + + await using var venvRunner = new PyVenvRunner(venvPath); venvRunner.WorkingDirectory = installLocation; await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false); @@ -165,29 +185,42 @@ public class InvokeAI : BaseGitPackage progress?.Report(new ProgressReport(-1f, "Installing Package", isIndeterminate: true)); var pipCommandArgs = - "InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu"; + "-e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu"; switch (torchVersion) { + // If has Nvidia Gpu, install CUDA version case TorchVersion.Cuda: + await InstallCudaTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); Logger.Info("Starting InvokeAI install (CUDA)..."); pipCommandArgs = - "InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117"; + "-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118"; break; - + // For AMD, Install ROCm version case TorchVersion.Rocm: + await venvRunner + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithExtraIndex("rocm5.4.2"), + onConsoleOutput + ) + .ConfigureAwait(false); Logger.Info("Starting InvokeAI install (ROCm)..."); pipCommandArgs = - "InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; + "-e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; break; - case TorchVersion.Mps: + // For Apple silicon, use MPS Logger.Info("Starting InvokeAI install (MPS)..."); - pipCommandArgs = "InvokeAI --use-pep517"; + pipCommandArgs = "-e . --use-pep517"; break; } - await venvRunner.PipInstall(pipCommandArgs, onConsoleOutput).ConfigureAwait(false); + await venvRunner + .PipInstall($"{pipCommandArgs}{(exists ? " --upgrade" : "")}", onConsoleOutput) + .ConfigureAwait(false); await venvRunner .PipInstall("rich packaging python-dotenv", onConsoleOutput) @@ -207,75 +240,6 @@ public class InvokeAI : BaseGitPackage progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false)); } - public override async Task Update( - InstalledPackage installedPackage, - TorchVersion torchVersion, - IProgress? progress = null, - bool includePrerelease = false, - Action? onConsoleOutput = null - ) - { - progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); - - if (installedPackage.FullPath is null || installedPackage.Version is null) - { - throw new NullReferenceException("Installed package is missing Path and/or Version"); - } - - await using var venvRunner = new PyVenvRunner( - Path.Combine(installedPackage.FullPath, "venv") - ); - venvRunner.WorkingDirectory = installedPackage.FullPath; - venvRunner.EnvironmentVariables = GetEnvVars(installedPackage.FullPath); - - var latestVersion = await GetUpdateVersion(installedPackage).ConfigureAwait(false); - var isReleaseMode = installedPackage.Version.IsReleaseMode; - - var downloadUrl = isReleaseMode - ? $"https://github.com/invoke-ai/InvokeAI/archive/{latestVersion}.zip" - : $"https://github.com/invoke-ai/InvokeAI/archive/refs/heads/{installedPackage.Version.InstalledBranch}.zip"; - - var gpus = HardwareHelper.IterGpuInfo().ToList(); - - progress?.Report(new ProgressReport(-1f, "Installing Package", isIndeterminate: true)); - - var pipCommandArgs = - $"\"InvokeAI @ {downloadUrl}\" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu --upgrade"; - - switch (torchVersion) - { - // If has Nvidia Gpu, install CUDA version - case TorchVersion.Cuda: - Logger.Info("Starting InvokeAI install (CUDA)..."); - pipCommandArgs = - $"\"InvokeAI[xformers] @ {downloadUrl}\" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117 --upgrade"; - break; - // For AMD, Install ROCm version - case TorchVersion.Rocm: - Logger.Info("Starting InvokeAI install (ROCm)..."); - pipCommandArgs = - $"\"InvokeAI @ {downloadUrl}\" --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2 --upgrade"; - break; - case TorchVersion.Mps: - // For Apple silicon, use MPS - Logger.Info("Starting InvokeAI install (MPS)..."); - pipCommandArgs = $"\"InvokeAI @ {downloadUrl}\" --use-pep517 --upgrade"; - break; - } - - await venvRunner.PipInstall(pipCommandArgs, onConsoleOutput).ConfigureAwait(false); - - progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false)); - - return isReleaseMode - ? new InstalledPackageVersion { InstalledReleaseVersion = latestVersion } - : new InstalledPackageVersion - { - InstalledBranch = installedPackage.Version.InstalledBranch, - InstalledCommitSha = latestVersion - }; - } - public override Task RunPackage( string installedPackagePath, string command, @@ -283,27 +247,6 @@ public class InvokeAI : BaseGitPackage Action? onConsoleOutput ) => RunInvokeCommand(installedPackagePath, command, arguments, true, onConsoleOutput); - private async Task GetUpdateVersion( - InstalledPackage installedPackage, - bool includePrerelease = false - ) - { - if (installedPackage.Version == null) - throw new NullReferenceException("Installed package version is null"); - - if (installedPackage.Version.IsReleaseMode) - { - var releases = await GetAllReleases().ConfigureAwait(false); - var latestRelease = releases.First(x => includePrerelease || !x.Prerelease); - return latestRelease.TagName; - } - - var allCommits = await GetAllCommits(installedPackage.Version.InstalledBranch) - .ConfigureAwait(false); - var latestCommit = allCommits.First(); - return latestCommit.Sha; - } - private async Task RunInvokeCommand( string installedPackagePath, string command, @@ -317,7 +260,6 @@ public class InvokeAI : BaseGitPackage arguments = command switch { "invokeai-configure" => "--yes --skip-sd-weights", - "invokeai-model-install" => "--yes", _ => arguments }; @@ -340,6 +282,21 @@ public class InvokeAI : BaseGitPackage // above the minimum in invokeai.frontend.install.widgets var code = $""" + try: + import os + import shutil + from invokeai.frontend.install import widgets + + _min_cols = widgets.MIN_COLS + _min_lines = widgets.MIN_LINES + + static_size_fn = lambda: os.terminal_size((_min_cols, _min_lines)) + shutil.get_terminal_size = static_size_fn + widgets.get_terminal_size = static_size_fn + except Exception as e: + import warnings + warnings.warn('Could not patch terminal size for InvokeAI' + str(e)) + import sys from {split[0]} import {split[1]} sys.exit({split[1]}()) diff --git a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs new file mode 100644 index 00000000..ed9d92fc --- /dev/null +++ b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs @@ -0,0 +1,251 @@ +using System.Text.RegularExpressions; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Models.Packages; + +[Singleton(typeof(BasePackage))] +public class KohyaSs : BaseGitPackage +{ + public KohyaSs( + IGithubApiCache githubApi, + ISettingsManager settingsManager, + IDownloadService downloadService, + IPrerequisiteHelper prerequisiteHelper + ) + : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } + + public override string Name => "kohya_ss"; + public override string DisplayName { get; set; } = "kohya_ss"; + public override string Author => "bmaltais"; + public override string Blurb => + "A Windows-focused Gradio GUI for Kohya's Stable Diffusion trainers"; + public override string LicenseType => "Apache-2.0"; + public override string LicenseUrl => + "https://github.com/bmaltais/kohya_ss/blob/master/LICENSE.md"; + public override string LaunchCommand => "kohya_gui.py"; + + public override Uri PreviewImageUri => + new( + "https://camo.githubusercontent.com/2170d2204816f428eec57ff87218f06344e0b4d91966343a6c5f0a76df91ec75/68747470733a2f2f696d672e796f75747562652e636f6d2f76692f6b35696d713031757655592f302e6a7067" + ); + public override string OutputFolderName => string.Empty; + + public override bool IsCompatible => HardwareHelper.HasNvidiaGpu(); + + public override TorchVersion GetRecommendedTorchVersion() => TorchVersion.Cuda; + + public override string Disclaimer => + "Nvidia GPU with at least 8GB VRAM is recommended. May be unstable on Linux."; + + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.UltraNightmare; + + public override bool OfferInOneClickInstaller => false; + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.None; + public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.Cuda }; + public override IEnumerable AvailableSharedFolderMethods => + new[] { SharedFolderMethod.None }; + + public override List LaunchOptions => + new() + { + new LaunchOptionDefinition + { + Name = "Listen Address", + Type = LaunchOptionType.String, + DefaultValue = "127.0.0.1", + Options = new List { "--listen" } + }, + new LaunchOptionDefinition + { + Name = "Port", + Type = LaunchOptionType.String, + Options = new List { "--port" } + }, + new LaunchOptionDefinition + { + Name = "Username", + Type = LaunchOptionType.String, + Options = new List { "--username" } + }, + new LaunchOptionDefinition + { + Name = "Password", + Type = LaunchOptionType.String, + Options = new List { "--password" } + }, + new LaunchOptionDefinition + { + Name = "Auto-Launch Browser", + Type = LaunchOptionType.Bool, + Options = new List { "--inbrowser" } + }, + new LaunchOptionDefinition + { + Name = "Share", + Type = LaunchOptionType.Bool, + Options = new List { "--share" } + }, + new LaunchOptionDefinition + { + Name = "Headless", + Type = LaunchOptionType.Bool, + Options = new List { "--headless" } + }, + new LaunchOptionDefinition + { + Name = "Language", + Type = LaunchOptionType.String, + Options = new List { "--language" } + }, + LaunchOptionDefinition.Extras + }; + + public override async Task InstallPackage( + string installLocation, + TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, + IProgress? progress = null, + Action? onConsoleOutput = null + ) + { + if (Compat.IsWindows) + { + progress?.Report( + new ProgressReport(-1f, "Installing prerequisites...", isIndeterminate: true) + ); + await PrerequisiteHelper.InstallTkinterIfNecessary(progress).ConfigureAwait(false); + } + + progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); + // Setup venv + await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); + venvRunner.WorkingDirectory = installLocation; + await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false); + + if (Compat.IsWindows) + { + var setupSmPath = Path.Combine(installLocation, "setup", "setup_sm.py"); + var setupText = """ + import setup_windows + import setup_common + + setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=False) + setup_windows.sync_bits_and_bytes_files() + setup_common.configure_accelerate(run_accelerate=False) + """; + await File.WriteAllTextAsync(setupSmPath, setupText).ConfigureAwait(false); + + // Install + venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput); + await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); + } + else if (Compat.IsLinux) + { + venvRunner.RunDetached( + "setup/setup_linux.py --platform-requirements-file=requirements_linux.txt --no_run_accelerate", + onConsoleOutput + ); + await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); + } + } + + public override async Task RunPackage( + string installedPackagePath, + string command, + string arguments, + Action? onConsoleOutput + ) + { + await SetupVenv(installedPackagePath).ConfigureAwait(false); + + // update gui files to point to venv accelerate + var filesToUpdate = new[] + { + "lora_gui.py", + "dreambooth_gui.py", + "textual_inversion_gui.py", + Path.Combine("library", "wd14_caption_gui.py"), + "finetune_gui.py" + }; + + foreach (var file in filesToUpdate) + { + var path = Path.Combine(installedPackagePath, file); + var text = await File.ReadAllTextAsync(path).ConfigureAwait(false); + var replacementAcceleratePath = Compat.IsWindows + ? @".\\venv\\scripts\\accelerate" + : "./venv/bin/accelerate"; + text = text.Replace( + "run_cmd = f'accelerate launch", + $"run_cmd = f'{replacementAcceleratePath} launch" + ); + await File.WriteAllTextAsync(path, text).ConfigureAwait(false); + } + + void HandleConsoleOutput(ProcessOutput s) + { + onConsoleOutput?.Invoke(s); + + if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase)) + return; + + var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); + var match = regex.Match(s.Text); + if (!match.Success) + return; + + WebUrl = match.Value; + OnStartupComplete(WebUrl); + } + + var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; + + VenvRunner.EnvironmentVariables = GetEnvVars(installedPackagePath); + VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); + } + + public override Dictionary>? SharedFolders { get; } + public override Dictionary< + SharedOutputType, + IReadOnlyList + >? SharedOutputFolders { get; } + + public override async Task GetLatestVersion() + { + var release = await GetLatestRelease().ConfigureAwait(false); + return release.TagName!; + } + + private Dictionary GetEnvVars(string installDirectory) + { + // Set additional required environment variables + var env = new Dictionary(); + if (SettingsManager.Settings.EnvironmentVariables is not null) + { + env.Update(SettingsManager.Settings.EnvironmentVariables); + } + + if (!Compat.IsWindows) + return env; + + var tkPath = Path.Combine( + SettingsManager.LibraryDir, + "Assets", + "Python310", + "tcl", + "tcl8.6" + ); + env["TCL_LIBRARY"] = tkPath; + env["TK_LIBRARY"] = tkPath; + + return env; + } +} diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionDirectMl.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionDirectMl.cs new file mode 100644 index 00000000..3d5bf19b --- /dev/null +++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionDirectMl.cs @@ -0,0 +1,251 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using System.Text.RegularExpressions; +using NLog; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Models.Packages; + +[Singleton(typeof(BasePackage))] +public class StableDiffusionDirectMl : BaseGitPackage +{ + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + public override string Name => "stable-diffusion-webui-directml"; + public override string DisplayName { get; set; } = "Stable Diffusion Web UI"; + public override string Author => "lshqqytiger"; + public override string LicenseType => "AGPL-3.0"; + public override string LicenseUrl => + "https://github.com/lshqqytiger/stable-diffusion-webui-directml/blob/master/LICENSE.txt"; + public override string Blurb => + "A fork of Automatic1111's Stable Diffusion WebUI with DirectML support"; + public override string LaunchCommand => "launch.py"; + public override Uri PreviewImageUri => + new( + "https://github.com/lshqqytiger/stable-diffusion-webui-directml/raw/master/screenshot.png" + ); + + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; + + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Recommended; + + public StableDiffusionDirectMl( + IGithubApiCache githubApi, + ISettingsManager settingsManager, + IDownloadService downloadService, + IPrerequisiteHelper prerequisiteHelper + ) + : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } + + public override Dictionary> SharedFolders => + new() + { + [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" }, + [SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" }, + [SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" }, + [SharedFolderType.SwinIR] = new[] { "models/SwinIR" }, + [SharedFolderType.Lora] = new[] { "models/Lora" }, + [SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" }, + [SharedFolderType.ApproxVAE] = new[] { "models/VAE-approx" }, + [SharedFolderType.VAE] = new[] { "models/VAE" }, + [SharedFolderType.DeepDanbooru] = new[] { "models/deepbooru" }, + [SharedFolderType.Karlo] = new[] { "models/karlo" }, + [SharedFolderType.TextualInversion] = new[] { "embeddings" }, + [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" }, + [SharedFolderType.ControlNet] = new[] { "models/ControlNet" }, + [SharedFolderType.Codeformer] = new[] { "models/Codeformer" }, + [SharedFolderType.LDSR] = new[] { "models/LDSR" }, + [SharedFolderType.AfterDetailer] = new[] { "models/adetailer" } + }; + + public override Dictionary>? SharedOutputFolders => + new() + { + [SharedOutputType.Extras] = new[] { "outputs/extras-images" }, + [SharedOutputType.Saved] = new[] { "log/images" }, + [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" }, + [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" }, + [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" }, + [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" } + }; + + [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] + public override List LaunchOptions => + new() + { + new() + { + Name = "Host", + Type = LaunchOptionType.String, + DefaultValue = "localhost", + Options = new() { "--server-name" } + }, + new() + { + Name = "Port", + Type = LaunchOptionType.String, + DefaultValue = "7860", + Options = new() { "--port" } + }, + new() + { + Name = "VRAM", + Type = LaunchOptionType.Bool, + InitialValue = HardwareHelper + .IterGpuInfo() + .Select(gpu => gpu.MemoryLevel) + .Max() switch + { + Level.Low => "--lowvram", + Level.Medium => "--medvram", + _ => null + }, + Options = new() { "--lowvram", "--medvram", "--medvram-sdxl" } + }, + new() + { + Name = "Xformers", + Type = LaunchOptionType.Bool, + InitialValue = HardwareHelper.HasNvidiaGpu(), + Options = new() { "--xformers" } + }, + new() + { + Name = "API", + Type = LaunchOptionType.Bool, + InitialValue = true, + Options = new() { "--api" } + }, + new() + { + Name = "Auto Launch Web UI", + Type = LaunchOptionType.Bool, + InitialValue = false, + Options = new() { "--autolaunch" } + }, + new() + { + Name = "Skip Torch CUDA Check", + Type = LaunchOptionType.Bool, + InitialValue = !HardwareHelper.HasNvidiaGpu(), + Options = new() { "--skip-torch-cuda-test" } + }, + new() + { + Name = "Skip Python Version Check", + Type = LaunchOptionType.Bool, + InitialValue = true, + Options = new() { "--skip-python-version-check" } + }, + new() + { + Name = "No Half", + Type = LaunchOptionType.Bool, + Description = "Do not switch the model to 16-bit floats", + InitialValue = HardwareHelper.HasAmdGpu(), + Options = new() { "--no-half" } + }, + new() + { + Name = "Skip SD Model Download", + Type = LaunchOptionType.Bool, + InitialValue = false, + Options = new() { "--no-download-sd-model" } + }, + new() + { + Name = "Skip Install", + Type = LaunchOptionType.Bool, + Options = new() { "--skip-install" } + }, + LaunchOptionDefinition.Extras + }; + + public override IEnumerable AvailableSharedFolderMethods => + new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; + + public override IEnumerable AvailableTorchVersions => + new[] { TorchVersion.Cpu, TorchVersion.DirectMl }; + + public override Task GetLatestVersion() => Task.FromResult("master"); + + public override bool ShouldIgnoreReleases => true; + + public override string OutputFolderName => "outputs"; + + public override async Task InstallPackage( + string installLocation, + TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, + IProgress? progress = null, + Action? onConsoleOutput = null + ) + { + progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); + // Setup venv + await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); + venvRunner.WorkingDirectory = installLocation; + await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false); + + switch (torchVersion) + { + case TorchVersion.DirectMl: + await InstallDirectMlTorch(venvRunner, progress, onConsoleOutput) + .ConfigureAwait(false); + break; + case TorchVersion.Cpu: + await InstallCpuTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + break; + } + + // Install requirements file + progress?.Report( + new ProgressReport(-1f, "Installing Package Requirements", isIndeterminate: true) + ); + Logger.Info("Installing requirements_versions.txt"); + + var requirements = new FilePath(installLocation, "requirements_versions.txt"); + await venvRunner + .PipInstallFromRequirements(requirements, onConsoleOutput, excludes: "torch") + .ConfigureAwait(false); + + progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false)); + } + + public override async Task RunPackage( + string installedPackagePath, + string command, + string arguments, + Action? onConsoleOutput + ) + { + await SetupVenv(installedPackagePath).ConfigureAwait(false); + + void HandleConsoleOutput(ProcessOutput s) + { + onConsoleOutput?.Invoke(s); + + if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase)) + return; + + var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); + var match = regex.Match(s.Text); + if (!match.Success) + return; + + WebUrl = match.Value; + OnStartupComplete(WebUrl); + } + + var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; + + VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); + } +} diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs new file mode 100644 index 00000000..6189d40e --- /dev/null +++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs @@ -0,0 +1,277 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using System.Text.RegularExpressions; +using NLog; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Models.Packages; + +[Singleton(typeof(BasePackage))] +public class StableDiffusionUx : BaseGitPackage +{ + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + public override string Name => "stable-diffusion-webui-ux"; + public override string DisplayName { get; set; } = "Stable Diffusion Web UI-UX"; + public override string Author => "anapnoe"; + public override string LicenseType => "AGPL-3.0"; + public override string LicenseUrl => + "https://github.com/anapnoe/stable-diffusion-webui-ux/blob/master/LICENSE.txt"; + public override string Blurb => + "A pixel perfect design, mobile friendly, customizable interface that adds accessibility, " + + "ease of use and extended functionallity to the stable diffusion web ui."; + public override string LaunchCommand => "launch.py"; + public override Uri PreviewImageUri => + new( + "https://raw.githubusercontent.com/anapnoe/stable-diffusion-webui-ux/master/screenshot.png" + ); + + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; + + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Simple; + + public StableDiffusionUx( + IGithubApiCache githubApi, + ISettingsManager settingsManager, + IDownloadService downloadService, + IPrerequisiteHelper prerequisiteHelper + ) + : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { } + + public override Dictionary> SharedFolders => + new() + { + [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" }, + [SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" }, + [SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" }, + [SharedFolderType.SwinIR] = new[] { "models/SwinIR" }, + [SharedFolderType.Lora] = new[] { "models/Lora" }, + [SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" }, + [SharedFolderType.ApproxVAE] = new[] { "models/VAE-approx" }, + [SharedFolderType.VAE] = new[] { "models/VAE" }, + [SharedFolderType.DeepDanbooru] = new[] { "models/deepbooru" }, + [SharedFolderType.Karlo] = new[] { "models/karlo" }, + [SharedFolderType.TextualInversion] = new[] { "embeddings" }, + [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" }, + [SharedFolderType.ControlNet] = new[] { "models/ControlNet" }, + [SharedFolderType.Codeformer] = new[] { "models/Codeformer" }, + [SharedFolderType.LDSR] = new[] { "models/LDSR" }, + [SharedFolderType.AfterDetailer] = new[] { "models/adetailer" } + }; + + public override Dictionary>? SharedOutputFolders => + new() + { + [SharedOutputType.Extras] = new[] { "outputs/extras-images" }, + [SharedOutputType.Saved] = new[] { "log/images" }, + [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" }, + [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" }, + [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" }, + [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" } + }; + + [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] + public override List LaunchOptions => + new() + { + new() + { + Name = "Host", + Type = LaunchOptionType.String, + DefaultValue = "localhost", + Options = new() { "--server-name" } + }, + new() + { + Name = "Port", + Type = LaunchOptionType.String, + DefaultValue = "7860", + Options = new() { "--port" } + }, + new() + { + Name = "VRAM", + Type = LaunchOptionType.Bool, + InitialValue = HardwareHelper + .IterGpuInfo() + .Select(gpu => gpu.MemoryLevel) + .Max() switch + { + Level.Low => "--lowvram", + Level.Medium => "--medvram", + _ => null + }, + Options = new() { "--lowvram", "--medvram", "--medvram-sdxl" } + }, + new() + { + Name = "Xformers", + Type = LaunchOptionType.Bool, + InitialValue = HardwareHelper.HasNvidiaGpu(), + Options = new() { "--xformers" } + }, + new() + { + Name = "API", + Type = LaunchOptionType.Bool, + InitialValue = true, + Options = new() { "--api" } + }, + new() + { + Name = "Auto Launch Web UI", + Type = LaunchOptionType.Bool, + InitialValue = false, + Options = new() { "--autolaunch" } + }, + new() + { + Name = "Skip Torch CUDA Check", + Type = LaunchOptionType.Bool, + InitialValue = !HardwareHelper.HasNvidiaGpu(), + Options = new() { "--skip-torch-cuda-test" } + }, + new() + { + Name = "Skip Python Version Check", + Type = LaunchOptionType.Bool, + InitialValue = true, + Options = new() { "--skip-python-version-check" } + }, + new() + { + Name = "No Half", + Type = LaunchOptionType.Bool, + Description = "Do not switch the model to 16-bit floats", + InitialValue = HardwareHelper.HasAmdGpu(), + Options = new() { "--no-half" } + }, + new() + { + Name = "Skip SD Model Download", + Type = LaunchOptionType.Bool, + InitialValue = false, + Options = new() { "--no-download-sd-model" } + }, + new() + { + Name = "Skip Install", + Type = LaunchOptionType.Bool, + Options = new() { "--skip-install" } + }, + LaunchOptionDefinition.Extras + }; + + public override IEnumerable AvailableSharedFolderMethods => + new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; + + public override IEnumerable AvailableTorchVersions => + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm }; + + public override Task GetLatestVersion() => Task.FromResult("master"); + + public override bool ShouldIgnoreReleases => true; + + public override string OutputFolderName => "outputs"; + + public override async Task InstallPackage( + string installLocation, + TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, + IProgress? progress = null, + Action? onConsoleOutput = null + ) + { + progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); + // Setup venv + await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); + venvRunner.WorkingDirectory = installLocation; + await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false); + + switch (torchVersion) + { + case TorchVersion.Cpu: + await InstallCpuTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + break; + case TorchVersion.Cuda: + await InstallCudaTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + break; + case TorchVersion.Rocm: + await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + break; + } + + // Install requirements file + progress?.Report( + new ProgressReport(-1f, "Installing Package Requirements", isIndeterminate: true) + ); + Logger.Info("Installing requirements_versions.txt"); + + var requirements = new FilePath(installLocation, "requirements_versions.txt"); + await venvRunner + .PipInstallFromRequirements(requirements, onConsoleOutput, excludes: "torch") + .ConfigureAwait(false); + + progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false)); + } + + public override async Task RunPackage( + string installedPackagePath, + string command, + string arguments, + Action? onConsoleOutput + ) + { + await SetupVenv(installedPackagePath).ConfigureAwait(false); + + void HandleConsoleOutput(ProcessOutput s) + { + onConsoleOutput?.Invoke(s); + + if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase)) + return; + + var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); + var match = regex.Match(s.Text); + if (!match.Success) + return; + + WebUrl = match.Value; + OnStartupComplete(WebUrl); + } + + var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; + + VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); + } + + private async Task InstallRocmTorch( + PyVenvRunner venvRunner, + IProgress? progress = null, + Action? onConsoleOutput = null + ) + { + progress?.Report( + new ProgressReport(-1f, "Installing PyTorch for ROCm", isIndeterminate: true) + ); + + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); + + await venvRunner + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.1.1"), + onConsoleOutput + ) + .ConfigureAwait(false); + } +} diff --git a/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs b/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs index 5be7baf0..6468a831 100644 --- a/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs +++ b/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs @@ -26,6 +26,10 @@ public class UnknownPackage : BasePackage public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; + public override string OutputFolderName { get; } + + public override PackageDifficulty InstallerSortOrder { get; } + public override Task DownloadPackage( string installLocation, DownloadPackageVersionOptions versionOptions, @@ -39,6 +43,7 @@ public class UnknownPackage : BasePackage public override Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) @@ -83,6 +88,16 @@ public class UnknownPackage : BasePackage throw new NotImplementedException(); } + public override Task SetupOutputFolderLinks(DirectoryPath installDirectory) + { + throw new NotImplementedException(); + } + + public override Task RemoveOutputFolderLinks(DirectoryPath installDirectory) + { + throw new NotImplementedException(); + } + public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.Cuda, TorchVersion.Cpu, TorchVersion.Rocm, TorchVersion.DirectMl }; @@ -108,6 +123,7 @@ public class UnknownPackage : BasePackage public override Task Update( InstalledPackage installedPackage, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, bool includePrerelease = false, Action? onConsoleOutput = null @@ -122,6 +138,12 @@ public class UnknownPackage : BasePackage public override List LaunchOptions => new(); + public override Dictionary>? SharedFolders { get; } + public override Dictionary< + SharedOutputType, + IReadOnlyList + >? SharedOutputFolders { get; } + public override Task GetLatestVersion() => Task.FromResult(string.Empty); public override Task GetAllVersionOptions() => diff --git a/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs b/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs index 08fd6fe9..7055fe5d 100644 --- a/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs +++ b/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs @@ -4,6 +4,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.RegularExpressions; using NLog; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.FileInterfaces; @@ -14,6 +15,7 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class VladAutomatic : BaseGitPackage { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -32,9 +34,10 @@ public class VladAutomatic : BaseGitPackage public override bool ShouldIgnoreReleases => true; public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Expert; public override IEnumerable AvailableTorchVersions => - new[] { TorchVersion.Cpu, TorchVersion.Rocm, TorchVersion.DirectMl, TorchVersion.Cuda }; + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm }; public VladAutomatic( IGithubApiCache githubApi, @@ -67,6 +70,19 @@ public class VladAutomatic : BaseGitPackage [SharedFolderType.ControlNet] = new[] { "models/ControlNet" } }; + public override Dictionary>? SharedOutputFolders => + new() + { + [SharedOutputType.Text2Img] = new[] { "outputs/text" }, + [SharedOutputType.Img2Img] = new[] { "outputs/image" }, + [SharedOutputType.Extras] = new[] { "outputs/extras" }, + [SharedOutputType.Img2ImgGrids] = new[] { "outputs/grids" }, + [SharedOutputType.Text2ImgGrids] = new[] { "outputs/grids" }, + [SharedOutputType.Saved] = new[] { "outputs/save" }, + }; + + public override string OutputFolderName => "outputs"; + [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] public override List LaunchOptions => new() @@ -161,6 +177,7 @@ public class VladAutomatic : BaseGitPackage public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) @@ -225,6 +242,7 @@ public class VladAutomatic : BaseGitPackage await PrerequisiteHelper .RunGit( installDir.Parent ?? "", + null, "clone", "https://github.com/vladmandic/automatic", installDir.Name @@ -232,7 +250,7 @@ public class VladAutomatic : BaseGitPackage .ConfigureAwait(false); await PrerequisiteHelper - .RunGit(installLocation, "checkout", downloadOptions.CommitHash) + .RunGit(installLocation, null, "checkout", downloadOptions.CommitHash) .ConfigureAwait(false); } else if (!string.IsNullOrWhiteSpace(downloadOptions.BranchName)) @@ -240,6 +258,7 @@ public class VladAutomatic : BaseGitPackage await PrerequisiteHelper .RunGit( installDir.Parent ?? "", + null, "clone", "-b", downloadOptions.BranchName, @@ -288,16 +307,12 @@ public class VladAutomatic : BaseGitPackage public override async Task Update( InstalledPackage installedPackage, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, bool includePrerelease = false, Action? onConsoleOutput = null ) { - if (installedPackage.Version is null) - { - throw new Exception("Version is null"); - } - progress?.Report( new ProgressReport( -1f, @@ -308,7 +323,12 @@ public class VladAutomatic : BaseGitPackage ); await PrerequisiteHelper - .RunGit(installedPackage.FullPath, "checkout", installedPackage.Version.InstalledBranch) + .RunGit( + installedPackage.FullPath, + onConsoleOutput, + "checkout", + versionOptions.BranchName + ) .ConfigureAwait(false); var venvRunner = new PyVenvRunner(Path.Combine(installedPackage.FullPath!, "venv")); @@ -327,7 +347,7 @@ public class VladAutomatic : BaseGitPackage return new InstalledPackageVersion { - InstalledBranch = installedPackage.Version.InstalledBranch, + InstalledBranch = versionOptions.BranchName, InstalledCommitSha = output.Replace(Environment.NewLine, "").Replace("\n", "") }; } diff --git a/StabilityMatrix.Core/Models/Packages/VoltaML.cs b/StabilityMatrix.Core/Models/Packages/VoltaML.cs index b8ff9733..482e9dc9 100644 --- a/StabilityMatrix.Core/Models/Packages/VoltaML.cs +++ b/StabilityMatrix.Core/Models/Packages/VoltaML.cs @@ -1,4 +1,5 @@ using System.Text.RegularExpressions; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models.Progress; @@ -8,6 +9,7 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Models.Packages; +[Singleton(typeof(BasePackage))] public class VoltaML : BaseGitPackage { public override string Name => "voltaML-fast-stable-diffusion"; @@ -24,6 +26,8 @@ public class VoltaML : BaseGitPackage "https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8" ); + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced; + // There are releases but the manager just downloads the latest commit anyways, // so we'll just limit to commit mode to be more consistent public override bool ShouldIgnoreReleases => true; @@ -45,9 +49,20 @@ public class VoltaML : BaseGitPackage [SharedFolderType.TextualInversion] = new[] { "data/textual-inversion" }, }; + public override Dictionary>? SharedOutputFolders => + new() + { + [SharedOutputType.Text2Img] = new[] { "data/outputs/txt2img" }, + [SharedOutputType.Extras] = new[] { "data/outputs/extra" }, + [SharedOutputType.Img2Img] = new[] { "data/outputs/img2img" }, + }; + + public override string OutputFolderName => "data/outputs"; + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink; - public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.None }; + public override IEnumerable AvailableTorchVersions => + new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Mps }; public override IEnumerable AvailableSharedFolderMethods => new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None }; @@ -136,12 +151,11 @@ public class VoltaML : BaseGitPackage public override async Task InstallPackage( string installLocation, TorchVersion torchVersion, + DownloadPackageVersionOptions versionOptions, IProgress? progress = null, Action? onConsoleOutput = null ) { - await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); - // Setup venv progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true)); await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv")); diff --git a/StabilityMatrix.Core/Models/Settings/Settings.cs b/StabilityMatrix.Core/Models/Settings/Settings.cs index cb01f14e..8997f144 100644 --- a/StabilityMatrix.Core/Models/Settings/Settings.cs +++ b/StabilityMatrix.Core/Models/Settings/Settings.cs @@ -1,4 +1,5 @@ -using System.Globalization; +using System.Drawing; +using System.Globalization; using System.Text.Json.Serialization; using Semver; using StabilityMatrix.Core.Converters.Json; @@ -70,6 +71,11 @@ public class Settings /// public bool IsCompletionRemoveUnderscoresEnabled { get; set; } = true; + /// + /// Format for Inference output image file names + /// + public string? InferenceOutputImageFileNameFormat { get; set; } + /// /// Whether the Inference Image Viewer shows pixel grids at high zoom levels /// @@ -88,6 +94,9 @@ public class Settings public HashSet FavoriteModels { get; set; } = new(); + public Size InferenceImageSize { get; set; } = new(150, 190); + public Size OutputsImageSize { get; set; } = new(300, 300); + public void RemoveInstalledPackageAndUpdateActive(InstalledPackage package) { RemoveInstalledPackageAndUpdateActive(package.Id); diff --git a/StabilityMatrix.Core/Models/Settings/Size.cs b/StabilityMatrix.Core/Models/Settings/Size.cs new file mode 100644 index 00000000..a98e29f0 --- /dev/null +++ b/StabilityMatrix.Core/Models/Settings/Size.cs @@ -0,0 +1,10 @@ +namespace StabilityMatrix.Core.Models.Settings; + +public record struct Size(double Width, double Height) +{ + public static Size operator +(Size current, Size other) => + new(current.Width + other.Width, current.Height + other.Height); + + public static Size operator -(Size current, Size other) => + new(current.Width - other.Width, current.Height - other.Height); +} diff --git a/StabilityMatrix.Core/Models/SharedOutputType.cs b/StabilityMatrix.Core/Models/SharedOutputType.cs new file mode 100644 index 00000000..d4dede8a --- /dev/null +++ b/StabilityMatrix.Core/Models/SharedOutputType.cs @@ -0,0 +1,13 @@ +namespace StabilityMatrix.Core.Models; + +public enum SharedOutputType +{ + All, + Text2Img, + Img2Img, + Extras, + Text2ImgGrids, + Img2ImgGrids, + Saved, + Consolidated +} diff --git a/StabilityMatrix.Core/Models/TrackedDownload.cs b/StabilityMatrix.Core/Models/TrackedDownload.cs index 2411b95c..e4017d52 100644 --- a/StabilityMatrix.Core/Models/TrackedDownload.cs +++ b/StabilityMatrix.Core/Models/TrackedDownload.cs @@ -13,39 +13,40 @@ namespace StabilityMatrix.Core.Models; public class TrackedDownload { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - + [JsonIgnore] private IDownloadService? downloadService; - + [JsonIgnore] private Task? downloadTask; - + [JsonIgnore] private CancellationTokenSource? downloadCancellationTokenSource; - + [JsonIgnore] private CancellationTokenSource? downloadPauseTokenSource; - + [JsonIgnore] private CancellationTokenSource AggregateCancellationTokenSource => CancellationTokenSource.CreateLinkedTokenSource( downloadCancellationTokenSource?.Token ?? CancellationToken.None, - downloadPauseTokenSource?.Token ?? CancellationToken.None); - + downloadPauseTokenSource?.Token ?? CancellationToken.None + ); + public required Guid Id { get; init; } - + public required Uri SourceUrl { get; init; } - + public Uri? RedirectedUrl { get; init; } - + public required DirectoryPath DownloadDirectory { get; init; } - + public required string FileName { get; init; } - + public required string TempFileName { get; init; } - + public string? ExpectedHashSha256 { get; set; } - + [JsonIgnore] [MemberNotNullWhen(true, nameof(ExpectedHashSha256))] public bool ValidateHash => ExpectedHashSha256 is not null; @@ -54,22 +55,24 @@ public class TrackedDownload public ProgressState ProgressState { get; set; } = ProgressState.Inactive; public List ExtraCleanupFileNames { get; init; } = new(); - + // Used for restoring progress on load public long DownloadedBytes { get; set; } public long TotalBytes { get; set; } - + /// /// Optional context action to be invoked on completion /// public IContextAction? ContextAction { get; set; } - + [JsonIgnore] public Exception? Exception { get; private set; } - + + private int attempts; + #region Events private WeakEventManager? progressUpdateEventManager; - + public event EventHandler ProgressUpdate { add @@ -79,18 +82,18 @@ public class TrackedDownload } remove => progressUpdateEventManager?.RemoveEventHandler(value); } - + protected void OnProgressUpdate(ProgressReport e) { // Update downloaded and total bytes DownloadedBytes = Convert.ToInt64(e.Current); TotalBytes = Convert.ToInt64(e.Total); - + progressUpdateEventManager?.RaiseEvent(this, e, nameof(ProgressUpdate)); } - + private WeakEventManager? progressStateChangedEventManager; - + public event EventHandler ProgressStateChanged { add @@ -100,13 +103,13 @@ public class TrackedDownload } remove => progressStateChangedEventManager?.RemoveEventHandler(value); } - + protected void OnProgressStateChanged(ProgressState e) { progressStateChangedEventManager?.RaiseEvent(this, e, nameof(ProgressStateChanged)); } #endregion - + [MemberNotNull(nameof(downloadService))] private void EnsureDownloadService() { @@ -119,42 +122,51 @@ public class TrackedDownload private async Task StartDownloadTask(long resumeFromByte, CancellationToken cancellationToken) { var progress = new Progress(OnProgressUpdate); - + await downloadService!.ResumeDownloadToFileAsync( SourceUrl.ToString(), DownloadDirectory.JoinFile(TempFileName), resumeFromByte, progress, - cancellationToken: cancellationToken).ConfigureAwait(false); - + cancellationToken: cancellationToken + ); + // If hash validation is enabled, validate the hash if (ValidateHash) { - OnProgressUpdate(new ProgressReport(0, isIndeterminate: true, type: ProgressType.Hashing)); - var hash = await FileHash.GetSha256Async(DownloadDirectory.JoinFile(TempFileName), progress).ConfigureAwait(false); + OnProgressUpdate( + new ProgressReport(0, isIndeterminate: true, type: ProgressType.Hashing) + ); + var hash = await FileHash + .GetSha256Async(DownloadDirectory.JoinFile(TempFileName), progress) + .ConfigureAwait(false); if (hash != ExpectedHashSha256?.ToLowerInvariant()) { - throw new Exception($"Hash validation for {FileName} failed, expected {ExpectedHashSha256} but got {hash}"); + throw new Exception( + $"Hash validation for {FileName} failed, expected {ExpectedHashSha256} but got {hash}" + ); } } } - + public void Start() { if (ProgressState != ProgressState.Inactive) { - throw new InvalidOperationException($"Download state must be inactive to start, not {ProgressState}"); + throw new InvalidOperationException( + $"Download state must be inactive to start, not {ProgressState}" + ); } Logger.Debug("Starting download {Download}", FileName); - + EnsureDownloadService(); - + downloadCancellationTokenSource = new CancellationTokenSource(); downloadPauseTokenSource = new CancellationTokenSource(); - + downloadTask = StartDownloadTask(0, AggregateCancellationTokenSource.Token) .ContinueWith(OnDownloadTaskCompleted); - + ProgressState = ProgressState.Working; OnProgressStateChanged(ProgressState); } @@ -163,27 +175,31 @@ public class TrackedDownload { if (ProgressState != ProgressState.Inactive) { - Logger.Warn("Attempted to resume download {Download} but it is not paused ({State})", FileName, ProgressState); + Logger.Warn( + "Attempted to resume download {Download} but it is not paused ({State})", + FileName, + ProgressState + ); } Logger.Debug("Resuming download {Download}", FileName); - + // Read the temp file to get the current size var tempSize = 0L; - + var tempFile = DownloadDirectory.JoinFile(TempFileName); if (tempFile.Exists) { tempSize = tempFile.Info.Length; } - + EnsureDownloadService(); - + downloadCancellationTokenSource = new CancellationTokenSource(); downloadPauseTokenSource = new CancellationTokenSource(); - + downloadTask = StartDownloadTask(tempSize, AggregateCancellationTokenSource.Token) .ContinueWith(OnDownloadTaskCompleted); - + ProgressState = ProgressState.Working; OnProgressStateChanged(ProgressState); } @@ -192,22 +208,30 @@ public class TrackedDownload { if (ProgressState != ProgressState.Working) { - Logger.Warn("Attempted to pause download {Download} but it is not in progress ({State})", FileName, ProgressState); + Logger.Warn( + "Attempted to pause download {Download} but it is not in progress ({State})", + FileName, + ProgressState + ); return; } - + Logger.Debug("Pausing download {Download}", FileName); downloadPauseTokenSource?.Cancel(); } - + public void Cancel() { if (ProgressState is not (ProgressState.Working or ProgressState.Inactive)) { - Logger.Warn("Attempted to cancel download {Download} but it is not in progress ({State})", FileName, ProgressState); + Logger.Warn( + "Attempted to cancel download {Download} but it is not in progress ({State})", + FileName, + ProgressState + ); return; } - + Logger.Debug("Cancelling download {Download}", FileName); // Cancel token if it exists @@ -219,7 +243,7 @@ public class TrackedDownload else { DoCleanup(); - + ProgressState = ProgressState.Cancelled; OnProgressStateChanged(ProgressState); } @@ -238,7 +262,7 @@ public class TrackedDownload { Logger.Warn("Failed to delete temp file {TempFile}", TempFileName); } - + foreach (var extraFile in ExtraCleanupFileNames) { try @@ -251,7 +275,7 @@ public class TrackedDownload } } } - + /// /// Invoked by the task's completion callback /// @@ -272,7 +296,9 @@ public class TrackedDownload } else { - throw new InvalidOperationException("Download task was cancelled but neither cancellation token was cancelled."); + throw new InvalidOperationException( + "Download task was cancelled but neither cancellation token was cancelled." + ); } } // For faulted @@ -281,6 +307,23 @@ public class TrackedDownload // Set the exception Exception = task.Exception; + if ( + (Exception is IOException || Exception?.InnerException is IOException) + && attempts < 3 + ) + { + attempts++; + Logger.Warn( + "Download {Download} failed with {Exception}, retrying ({Attempt})", + FileName, + Exception, + attempts + ); + ProgressState = ProgressState.Inactive; + Resume(); + return; + } + ProgressState = ProgressState.Failed; } // Otherwise success @@ -293,23 +336,23 @@ public class TrackedDownload if (ProgressState is ProgressState.Failed or ProgressState.Cancelled) { DoCleanup(); - } + } else if (ProgressState == ProgressState.Success) { // Move the temp file to the final file DownloadDirectory.JoinFile(TempFileName).MoveTo(DownloadDirectory.JoinFile(FileName)); } - + // For pause, just do nothing - + OnProgressStateChanged(ProgressState); - + // Dispose of the task and cancellation token downloadTask = null; downloadCancellationTokenSource = null; downloadPauseTokenSource = null; } - + public void SetDownloadService(IDownloadService service) { downloadService = service; diff --git a/StabilityMatrix.Core/Processes/Argument.cs b/StabilityMatrix.Core/Processes/Argument.cs new file mode 100644 index 00000000..44f2c846 --- /dev/null +++ b/StabilityMatrix.Core/Processes/Argument.cs @@ -0,0 +1,6 @@ +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +[GenerateOneOf] +public partial class Argument : OneOfBase { } diff --git a/StabilityMatrix.Core/Processes/ProcessArgs.cs b/StabilityMatrix.Core/Processes/ProcessArgs.cs new file mode 100644 index 00000000..8fd296b3 --- /dev/null +++ b/StabilityMatrix.Core/Processes/ProcessArgs.cs @@ -0,0 +1,73 @@ +using System.Collections; +using System.Text.RegularExpressions; +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +/// +/// Parameter type for command line arguments +/// Implicitly converts between string and string[], +/// with no parsing if the input and output types are the same. +/// +public partial class ProcessArgs : OneOfBase, IEnumerable +{ + /// + public ProcessArgs(OneOf input) + : base(input) { } + + /// + /// Whether the argument string contains the given substring, + /// or any of the given arguments if the input is an array. + /// + public bool Contains(string arg) => + Match(str => str.Contains(arg), arr => arr.Any(x => x.Contains(arg))); + + public ProcessArgs Concat(ProcessArgs other) => + Match( + str => new ProcessArgs(string.Join(' ', str, other.ToString())), + arr => new ProcessArgs(arr.Concat(other.ToArray()).ToArray()) + ); + + public ProcessArgs Prepend(ProcessArgs other) => + Match( + str => new ProcessArgs(string.Join(' ', other.ToString(), str)), + arr => new ProcessArgs(other.ToArray().Concat(arr).ToArray()) + ); + + /// + public IEnumerator GetEnumerator() + { + return ToArray().AsEnumerable().GetEnumerator(); + } + + /// + public override string ToString() + { + return Match(str => str, arr => string.Join(' ', arr.Select(ProcessRunner.Quote))); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public string[] ToArray() => + Match( + str => ArgumentsRegex().Matches(str).Select(x => x.Value.Trim('"')).ToArray(), + arr => arr + ); + + // Implicit conversions + + public static implicit operator ProcessArgs(string input) => new(input); + + public static implicit operator ProcessArgs(string[] input) => new(input); + + public static implicit operator string(ProcessArgs input) => input.ToString(); + + public static implicit operator string[](ProcessArgs input) => input.ToArray(); + + [GeneratedRegex("""[\"].+?[\"]|[^ ]+""", RegexOptions.IgnoreCase)] + private static partial Regex ArgumentsRegex(); +} diff --git a/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs new file mode 100644 index 00000000..00a0e812 --- /dev/null +++ b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs @@ -0,0 +1,78 @@ +using System.Diagnostics; +using System.Diagnostics.Contracts; +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +/// +/// Builder for . +/// +public record ProcessArgsBuilder +{ + protected ProcessArgsBuilder() { } + + public ProcessArgsBuilder(params Argument[] arguments) + { + Arguments = arguments.ToList(); + } + + public List Arguments { get; init; } = new(); + + private IEnumerable ToStringArgs() + { + foreach (var argument in Arguments) + { + if (argument.IsT0) + { + yield return argument.AsT0; + } + else + { + yield return argument.AsT1.Item1; + yield return argument.AsT1.Item2; + } + } + } + + /// + public override string ToString() + { + return ToProcessArgs().ToString(); + } + + public ProcessArgs ToProcessArgs() + { + return ToStringArgs().ToArray(); + } + + public static implicit operator ProcessArgs(ProcessArgsBuilder builder) => + builder.ToProcessArgs(); +} + +public static class ProcessArgBuilderExtensions +{ + [Pure] + public static T AddArg(this T builder, Argument argument) + where T : ProcessArgsBuilder + { + return builder with { Arguments = builder.Arguments.Append(argument).ToList() }; + } + + [Pure] + public static T RemoveArgKey(this T builder, string argumentKey) + where T : ProcessArgsBuilder + { + return builder with + { + Arguments = builder.Arguments + .Where( + x => + x.Match( + stringArg => stringArg != argumentKey, + tupleArg => tupleArg.Item1 != argumentKey + ) + ) + .ToList() + }; + } +} diff --git a/StabilityMatrix.Core/Processes/ProcessRunner.cs b/StabilityMatrix.Core/Processes/ProcessRunner.cs index 94763cb7..3a20cc60 100644 --- a/StabilityMatrix.Core/Processes/ProcessRunner.cs +++ b/StabilityMatrix.Core/Processes/ProcessRunner.cs @@ -163,6 +163,55 @@ public static class ProcessRunner return output; } + public static async Task GetProcessResultAsync( + string fileName, + ProcessArgs arguments, + string? workingDirectory = null, + IReadOnlyDictionary? environmentVariables = null + ) + { + Logger.Debug($"Starting process '{fileName}' with arguments '{arguments}'"); + + var info = new ProcessStartInfo + { + FileName = fileName, + Arguments = arguments, + UseShellExecute = false, + RedirectStandardOutput = true, + RedirectStandardError = true, + CreateNoWindow = true + }; + + if (environmentVariables != null) + { + foreach (var (key, value) in environmentVariables) + { + info.EnvironmentVariables[key] = value; + } + } + + if (workingDirectory != null) + { + info.WorkingDirectory = workingDirectory; + } + + using var process = new Process(); + process.StartInfo = info; + StartTrackedProcess(process); + + var stdout = await process.StandardOutput.ReadToEndAsync().ConfigureAwait(false); + var stderr = await process.StandardError.ReadToEndAsync().ConfigureAwait(false); + + await process.WaitForExitAsync().ConfigureAwait(false); + + return new ProcessResult + { + ExitCode = process.ExitCode, + StandardOutput = stdout, + StandardError = stderr + }; + } + public static Process StartProcess( string fileName, string arguments, diff --git a/StabilityMatrix.Core/Python/PipIndexResult.cs b/StabilityMatrix.Core/Python/PipIndexResult.cs new file mode 100644 index 00000000..ea9e62f7 --- /dev/null +++ b/StabilityMatrix.Core/Python/PipIndexResult.cs @@ -0,0 +1,32 @@ +using System.Collections.Immutable; +using System.Text.RegularExpressions; +using StabilityMatrix.Core.Extensions; + +namespace StabilityMatrix.Core.Python; + +public partial record PipIndexResult +{ + public required IReadOnlyList AvailableVersions { get; init; } + + public static PipIndexResult Parse(string output) + { + var match = AvailableVersionsRegex().Matches(output); + + var versions = output + .SplitLines() + .Select(line => AvailableVersionsRegex().Match(line)) + .First(m => m.Success) + .Groups["versions"].Value + .Split( + new[] { ',' }, + StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries + ) + .ToImmutableArray(); + + return new PipIndexResult { AvailableVersions = versions }; + } + + // Regex, capture the line starting with "Available versions:" + [GeneratedRegex(@"^Available versions:\s*(?.*)$")] + private static partial Regex AvailableVersionsRegex(); +} diff --git a/StabilityMatrix.Core/Python/PipInstallArgs.cs b/StabilityMatrix.Core/Python/PipInstallArgs.cs new file mode 100644 index 00000000..d16aedf7 --- /dev/null +++ b/StabilityMatrix.Core/Python/PipInstallArgs.cs @@ -0,0 +1,31 @@ +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Core.Python; + +public record PipInstallArgs : ProcessArgsBuilder +{ + public PipInstallArgs(params Argument[] arguments) + : base(arguments) { } + + public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}"); + + public PipInstallArgs WithTorchDirectML(string version = "") => + this.AddArg($"torch-directml{version}"); + + public PipInstallArgs WithTorchVision(string version = "") => + this.AddArg($"torchvision{version}"); + + public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}"); + + public PipInstallArgs WithExtraIndex(string indexUrl) => + this.AddArg(("--extra-index-url", indexUrl)); + + public PipInstallArgs WithTorchExtraIndex(string index) => + this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}")); + + /// + public override string ToString() + { + return base.ToString(); + } +} diff --git a/StabilityMatrix.Core/Python/PipPackageInfo.cs b/StabilityMatrix.Core/Python/PipPackageInfo.cs new file mode 100644 index 00000000..eee1af99 --- /dev/null +++ b/StabilityMatrix.Core/Python/PipPackageInfo.cs @@ -0,0 +1,7 @@ +namespace StabilityMatrix.Core.Python; + +public readonly record struct PipPackageInfo( + string Name, + string Version, + string? EditableProjectLocation = null +); diff --git a/StabilityMatrix.Core/Python/PipShowResult.cs b/StabilityMatrix.Core/Python/PipShowResult.cs new file mode 100644 index 00000000..4bc132fe --- /dev/null +++ b/StabilityMatrix.Core/Python/PipShowResult.cs @@ -0,0 +1,57 @@ +using StabilityMatrix.Core.Extensions; + +namespace StabilityMatrix.Core.Python; + +public record PipShowResult +{ + public required string Name { get; init; } + + public required string Version { get; init; } + + public string? Summary { get; init; } + + public string? HomePage { get; init; } + + public string? Author { get; init; } + + public string? AuthorEmail { get; init; } + + public string? License { get; init; } + + public string? Location { get; init; } + + public List? Requires { get; init; } + + public List? RequiredBy { get; init; } + + public static PipShowResult Parse(string output) + { + // Decode each line by splitting on first ":" to key and value + var lines = output + .SplitLines(StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries) + .Select(line => line.Split(new[] { ':' }, 2)) + .Where(split => split.Length == 2) + .Select(split => new KeyValuePair(split[0].Trim(), split[1].Trim())) + .ToDictionary(pair => pair.Key, pair => pair.Value); + + return new PipShowResult + { + Name = lines["Name"], + Version = lines["Version"], + Summary = lines.GetValueOrDefault("Summary"), + HomePage = lines.GetValueOrDefault("Home-page"), + Author = lines.GetValueOrDefault("Author"), + AuthorEmail = lines.GetValueOrDefault("Author-email"), + License = lines.GetValueOrDefault("License"), + Location = lines.GetValueOrDefault("Location"), + Requires = lines + .GetValueOrDefault("Requires") + ?.Split(new[] { ',' }, StringSplitOptions.TrimEntries) + .ToList(), + RequiredBy = lines + .GetValueOrDefault("Required-by") + ?.Split(new[] { ',' }, StringSplitOptions.TrimEntries) + .ToList() + }; + } +} diff --git a/StabilityMatrix.Core/Python/PyRunner.cs b/StabilityMatrix.Core/Python/PyRunner.cs index e3e3dac0..ea7268d6 100644 --- a/StabilityMatrix.Core/Python/PyRunner.cs +++ b/StabilityMatrix.Core/Python/PyRunner.cs @@ -1,6 +1,7 @@ using System.Diagnostics.CodeAnalysis; using NLog; using Python.Runtime; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; @@ -11,42 +12,57 @@ using StabilityMatrix.Core.Python.Interop; namespace StabilityMatrix.Core.Python; [SuppressMessage("ReSharper", "NotAccessedPositionalProperty.Global")] -public record struct PyVersionInfo(int Major, int Minor, int Micro, string ReleaseLevel, int Serial); +public record struct PyVersionInfo( + int Major, + int Minor, + int Micro, + string ReleaseLevel, + int Serial +); [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] +[Singleton(typeof(IPyRunner))] public class PyRunner : IPyRunner { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - + // Set by ISettingsManager.TryFindLibrary() public static DirectoryPath HomeDir { get; set; } = string.Empty; - + // This is same for all platforms public const string PythonDirName = "Python310"; - - public static string PythonDir => Path.Combine(GlobalConfig.LibraryDir, "Assets", PythonDirName); + + public static string PythonDir => + Path.Combine(GlobalConfig.LibraryDir, "Assets", PythonDirName); /// /// Path to the Python Linked library relative from the Python directory. /// - public static string RelativePythonDllPath => Compat.Switch( - (PlatformKind.Windows, "python310.dll"), - (PlatformKind.Linux, Path.Combine("lib", "libpython3.10.so")), - (PlatformKind.MacOS, Path.Combine("lib", "libpython3.10.dylib"))); + public static string RelativePythonDllPath => + Compat.Switch( + (PlatformKind.Windows, "python310.dll"), + (PlatformKind.Linux, Path.Combine("lib", "libpython3.10.so")), + (PlatformKind.MacOS, Path.Combine("lib", "libpython3.10.dylib")) + ); public static string PythonDllPath => Path.Combine(PythonDir, RelativePythonDllPath); - public static string PythonExePath => Compat.Switch( - (PlatformKind.Windows, Path.Combine(PythonDir, "python.exe")), - (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "python3")), - (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "python3"))); - public static string PipExePath => Compat.Switch( - (PlatformKind.Windows, Path.Combine(PythonDir, "Scripts", "pip.exe")), - (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "pip3")), - (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "pip3"))); - + public static string PythonExePath => + Compat.Switch( + (PlatformKind.Windows, Path.Combine(PythonDir, "python.exe")), + (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "python3")), + (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "python3")) + ); + public static string PipExePath => + Compat.Switch( + (PlatformKind.Windows, Path.Combine(PythonDir, "Scripts", "pip.exe")), + (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "pip3")), + (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "pip3")) + ); + public static string GetPipPath => Path.Combine(PythonDir, "get-pip.pyc"); - public static string VenvPath => Path.Combine(PythonDir, "Scripts", "virtualenv" + Compat.ExeExtension); + public static string VenvPath => + Path.Combine(PythonDir, "Scripts", "virtualenv" + Compat.ExeExtension); public static bool PipInstalled => File.Exists(PipExePath); public static bool VenvInstalled => File.Exists(VenvPath); @@ -55,7 +71,7 @@ public class PyRunner : IPyRunner public PyIOStream? StdOutStream; public PyIOStream? StdErrStream; - + /// $ /// Initializes the Python runtime using the embedded dll. /// Can be called with no effect after initialization. @@ -63,10 +79,11 @@ public class PyRunner : IPyRunner /// Thrown if Python DLL not found. public async Task Initialize() { - if (PythonEngine.IsInitialized) return; - + if (PythonEngine.IsInitialized) + return; + Logger.Info("Setting PYTHONHOME={PythonDir}", PythonDir.ToRepr()); - + // Append Python path to PATH var newEnvPath = Compat.GetEnvPathWithExtensions(PythonDir); Logger.Debug("Setting PATH={NewEnvPath}", newEnvPath.ToRepr()); @@ -78,7 +95,7 @@ public class PyRunner : IPyRunner { throw new FileNotFoundException("Python linked library not found", PythonDllPath); } - + Runtime.PythonDLL = PythonDllPath; PythonEngine.PythonHome = PythonDir; PythonEngine.Initialize(); @@ -88,12 +105,14 @@ public class PyRunner : IPyRunner StdOutStream = new PyIOStream(); StdErrStream = new PyIOStream(); await RunInThreadWithLock(() => - { - var sys = Py.Import("sys") as PyModule ?? - throw new NullReferenceException("sys module not found"); - sys.Set("stdout", StdOutStream); - sys.Set("stderr", StdErrStream); - }).ConfigureAwait(false); + { + var sys = + Py.Import("sys") as PyModule + ?? throw new NullReferenceException("sys module not found"); + sys.Set("stdout", StdOutStream); + sys.Set("stderr", StdErrStream); + }) + .ConfigureAwait(false); } /// @@ -129,19 +148,26 @@ public class PyRunner : IPyRunner /// Time limit for waiting on PyRunning lock. /// Cancellation token. /// cancelToken was canceled, or waitTimeout expired. - public async Task RunInThreadWithLock(Func func, TimeSpan? waitTimeout = null, CancellationToken cancelToken = default) + public async Task RunInThreadWithLock( + Func func, + TimeSpan? waitTimeout = null, + CancellationToken cancelToken = default + ) { // Wait to acquire PyRunning lock await PyRunning.WaitAsync(cancelToken).ConfigureAwait(false); try { - return await Task.Run(() => - { - using (Py.GIL()) + return await Task.Run( + () => { - return func(); - } - }, cancelToken); + using (Py.GIL()) + { + return func(); + } + }, + cancelToken + ); } finally { @@ -156,19 +182,26 @@ public class PyRunner : IPyRunner /// Time limit for waiting on PyRunning lock. /// Cancellation token. /// cancelToken was canceled, or waitTimeout expired. - public async Task RunInThreadWithLock(Action action, TimeSpan? waitTimeout = null, CancellationToken cancelToken = default) + public async Task RunInThreadWithLock( + Action action, + TimeSpan? waitTimeout = null, + CancellationToken cancelToken = default + ) { // Wait to acquire PyRunning lock await PyRunning.WaitAsync(cancelToken).ConfigureAwait(false); try { - await Task.Run(() => - { - using (Py.GIL()) + await Task.Run( + () => { - action(); - } - }, cancelToken); + using (Py.GIL()) + { + action(); + } + }, + cancelToken + ); } finally { diff --git a/StabilityMatrix.Core/Python/PyVenvRunner.cs b/StabilityMatrix.Core/Python/PyVenvRunner.cs index f35b66ef..5b3ae3dc 100644 --- a/StabilityMatrix.Core/Python/PyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/PyVenvRunner.cs @@ -1,5 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Text; +using System.Text.Json; using System.Text.RegularExpressions; using NLog; using Salaros.Configuration; @@ -9,6 +10,7 @@ using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Processes; +using Yoh.Text.Json.NamingPolicies; namespace StabilityMatrix.Core.Python; @@ -19,20 +21,6 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - private const string TorchPipInstallArgs = "torch==2.0.1 torchvision"; - - public const string TorchPipInstallArgsCuda = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/cu118"; - public const string TorchPipInstallArgsCpu = TorchPipInstallArgs; - public const string TorchPipInstallArgsDirectML = "torch-directml"; - - public const string TorchPipInstallArgsRocm511 = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.1.1"; - public const string TorchPipInstallArgsRocm542 = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; - public const string TorchPipInstallArgsRocmNightly56 = - $"--pre {TorchPipInstallArgs} --index-url https://download.pytorch.org/whl/nightly/rocm5.6"; - /// /// Relative path to the site-packages folder from the venv root. /// This is platform specific. @@ -211,7 +199,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable /// Run a pip install command. Waits for the process to exit. /// workingDirectory defaults to RootPath. /// - public async Task PipInstall(string args, Action? outputDataReceived = null) + public async Task PipInstall(ProcessArgs args, Action? outputDataReceived = null) { if (!File.Exists(PipPath)) { @@ -231,7 +219,43 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable }); SetPyvenvCfg(PyRunner.PythonDir); - RunDetached($"-m pip install {args}", outputAction); + RunDetached(args.Prepend("-m pip install"), outputAction); + await Process.WaitForExitAsync().ConfigureAwait(false); + + // Check return code + if (Process.ExitCode != 0) + { + throw new ProcessException( + $"pip install failed with code {Process.ExitCode}: {output.ToString().ToRepr()}" + ); + } + } + + /// + /// Run a pip uninstall command. Waits for the process to exit. + /// workingDirectory defaults to RootPath. + /// + public async Task PipUninstall(string args, Action? outputDataReceived = null) + { + if (!File.Exists(PipPath)) + { + throw new FileNotFoundException("pip not found", PipPath); + } + + // Record output for errors + var output = new StringBuilder(); + + var outputAction = new Action(s => + { + Logger.Debug($"Pip output: {s.Text}"); + // Record to output + output.Append(s.Text); + // Forward to callback + outputDataReceived?.Invoke(s); + }); + + SetPyvenvCfg(PyRunner.PythonDir); + RunDetached($"-m pip uninstall -y {args}", outputAction); await Process.WaitForExitAsync().ConfigureAwait(false); // Check return code @@ -270,6 +294,149 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable await PipInstall(pipArgs, outputDataReceived).ConfigureAwait(false); } + /// + /// Run a pip list command, return results as PipPackageInfo objects. + /// + public async Task> PipList() + { + if (!File.Exists(PipPath)) + { + throw new FileNotFoundException("pip not found", PipPath); + } + + SetPyvenvCfg(PyRunner.PythonDir); + + var result = await ProcessRunner + .GetProcessResultAsync( + PythonPath, + "-m pip list --format=json", + WorkingDirectory?.FullPath, + EnvironmentVariables + ) + .ConfigureAwait(false); + + // Check return code + if (result.ExitCode != 0) + { + throw new ProcessException( + $"pip list failed with code {result.ExitCode}: {result.StandardOutput}, {result.StandardError}" + ); + } + + // Use only first line, since there might be pip update messages later + if ( + result.StandardOutput + ?.SplitLines(StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries) + .FirstOrDefault() + is not { } firstLine + ) + { + return new List(); + } + + return JsonSerializer.Deserialize>( + firstLine, + new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower + } + ) ?? new List(); + } + + /// + /// Run a pip show command, return results as PipPackageInfo objects. + /// + public async Task PipShow(string packageName) + { + if (!File.Exists(PipPath)) + { + throw new FileNotFoundException("pip not found", PipPath); + } + + SetPyvenvCfg(PyRunner.PythonDir); + + var result = await ProcessRunner + .GetProcessResultAsync( + PythonPath, + new[] { "-m", "pip", "show", packageName }, + WorkingDirectory?.FullPath, + EnvironmentVariables + ) + .ConfigureAwait(false); + + // Check return code + if (result.ExitCode != 0) + { + throw new ProcessException( + $"pip show failed with code {result.ExitCode}: {result.StandardOutput}, {result.StandardError}" + ); + } + + if (result.StandardOutput!.StartsWith("WARNING: Package(s) not found:")) + { + return null; + } + + return PipShowResult.Parse(result.StandardOutput); + } + + /// + /// Run a pip index command, return result as PipIndexResult. + /// + public async Task PipIndex(string packageName, string? indexUrl = null) + { + if (!File.Exists(PipPath)) + { + throw new FileNotFoundException("pip not found", PipPath); + } + + SetPyvenvCfg(PyRunner.PythonDir); + + var args = new ProcessArgsBuilder( + "-m", + "pip", + "index", + "versions", + packageName, + "--no-color", + "--disable-pip-version-check" + ); + + if (indexUrl is not null) + { + args = args.AddArg(("--index-url", indexUrl)); + } + + var result = await ProcessRunner + .GetProcessResultAsync( + PythonPath, + args, + WorkingDirectory?.FullPath, + EnvironmentVariables + ) + .ConfigureAwait(false); + + // Check return code + if (result.ExitCode != 0) + { + throw new ProcessException( + $"pip index failed with code {result.ExitCode}: {result.StandardOutput}, {result.StandardError}" + ); + } + + if ( + string.IsNullOrEmpty(result.StandardOutput) + || result.StandardOutput! + .SplitLines() + .Any(l => l.StartsWith("ERROR: No matching distribution found")) + ) + { + return null; + } + + return PipIndexResult.Parse(result.StandardOutput); + } + /// /// Run a custom install command. Waits for the process to exit. /// workingDirectory defaults to RootPath. @@ -308,7 +475,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable /// Run a command using the venv Python executable and return the result. /// /// Arguments to pass to the Python executable. - public async Task Run(string arguments) + public async Task Run(ProcessArgs arguments) { // Record output for errors var output = new StringBuilder(); @@ -340,12 +507,14 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable [MemberNotNull(nameof(Process))] public void RunDetached( - string arguments, + ProcessArgs args, Action? outputDataReceived, Action? onExit = null, bool unbuffered = true ) { + var arguments = args.ToString(); + if (!PythonPath.Exists) { throw new FileNotFoundException("Venv python not found", PythonPath); @@ -354,10 +523,10 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable Logger.Info( "Launching venv process [{PythonPath}] " - + "in working directory [{WorkingDirectory}] with args {arguments.ToRepr()}", + + "in working directory [{WorkingDirectory}] with args {Arguments}", PythonPath, WorkingDirectory, - arguments.ToRepr() + arguments ); var filteredOutput = @@ -380,7 +549,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable } // Disable pip caching - uses significant memory for large packages like torch - env["PIP_NO_CACHE_DIR"] = "true"; + // env["PIP_NO_CACHE_DIR"] = "true"; // On windows, add portable git to PATH and binary as GIT if (Compat.IsWindows) diff --git a/StabilityMatrix.Core/Services/DownloadService.cs b/StabilityMatrix.Core/Services/DownloadService.cs index 7e0be15a..91dbb573 100644 --- a/StabilityMatrix.Core/Services/DownloadService.cs +++ b/StabilityMatrix.Core/Services/DownloadService.cs @@ -1,10 +1,12 @@ using System.Net.Http.Headers; using Microsoft.Extensions.Logging; using Polly.Contrib.WaitAndRetry; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models.Progress; namespace StabilityMatrix.Core.Services; +[Singleton(typeof(IDownloadService))] public class DownloadService : IDownloadService { private readonly ILogger logger; diff --git a/StabilityMatrix.Core/Services/IImageIndexService.cs b/StabilityMatrix.Core/Services/IImageIndexService.cs index 5962df87..66430063 100644 --- a/StabilityMatrix.Core/Services/IImageIndexService.cs +++ b/StabilityMatrix.Core/Services/IImageIndexService.cs @@ -9,11 +9,6 @@ public interface IImageIndexService { IndexCollection InferenceImages { get; } - /// - /// Gets a list of local images that start with the given path prefix - /// - Task> GetLocalImagesByPrefix(string pathPrefix); - /// /// Refresh index for all collections /// @@ -25,9 +20,4 @@ public interface IImageIndexService /// Refreshes the index of local images in the background /// void BackgroundRefreshIndex(); - - /// - /// Removes a local image from the database - /// - Task RemoveImage(LocalImageFile imageFile); } diff --git a/StabilityMatrix.Core/Services/ISettingsManager.cs b/StabilityMatrix.Core/Services/ISettingsManager.cs index 462469ae..4b8a92d3 100644 --- a/StabilityMatrix.Core/Services/ISettingsManager.cs +++ b/StabilityMatrix.Core/Services/ISettingsManager.cs @@ -21,6 +21,7 @@ public interface ISettingsManager List PackageInstallsInProgress { get; set; } Settings Settings { get; } + DirectoryPath ConsolidatedImagesDirectory { get; } /// /// Event fired when the library directory is changed @@ -69,7 +70,8 @@ public interface ISettingsManager T source, Expression> sourceProperty, Expression> settingsProperty, - bool setInitial = false + bool setInitial = false, + TimeSpan? delay = null ) where T : INotifyPropertyChanged; diff --git a/StabilityMatrix.Core/Services/ImageIndexService.cs b/StabilityMatrix.Core/Services/ImageIndexService.cs index c7ee7739..3b6764b3 100644 --- a/StabilityMatrix.Core/Services/ImageIndexService.cs +++ b/StabilityMatrix.Core/Services/ImageIndexService.cs @@ -1,11 +1,10 @@ using System.Collections.Concurrent; using System.Diagnostics; -using System.Text.Json; using AsyncAwaitBestPractices; using DynamicData; -using DynamicData.Binding; using Microsoft.Extensions.Logging; -using StabilityMatrix.Core.Database; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Database; @@ -13,49 +12,34 @@ using StabilityMatrix.Core.Models.FileInterfaces; namespace StabilityMatrix.Core.Services; +[Singleton(typeof(IImageIndexService))] public class ImageIndexService : IImageIndexService { private readonly ILogger logger; - private readonly ILiteDbContext liteDbContext; private readonly ISettingsManager settingsManager; /// public IndexCollection InferenceImages { get; } - public ImageIndexService( - ILogger logger, - ILiteDbContext liteDbContext, - ISettingsManager settingsManager - ) + public ImageIndexService(ILogger logger, ISettingsManager settingsManager) { this.logger = logger; - this.liteDbContext = liteDbContext; this.settingsManager = settingsManager; InferenceImages = new IndexCollection( this, - file => file.RelativePath + file => file.AbsolutePath ) { - RelativePath = "inference" + RelativePath = "Inference" }; EventManager.Instance.ImageFileAdded += OnImageFileAdded; } - /// - public async Task> GetLocalImagesByPrefix(string pathPrefix) - { - return await liteDbContext.LocalImageFiles - .Query() - .Where(imageFile => imageFile.RelativePath.StartsWith(pathPrefix)) - .ToArrayAsync() - .ConfigureAwait(false); - } - - public async Task RefreshIndexForAllCollections() + public Task RefreshIndexForAllCollections() { - await RefreshIndex(InferenceImages).ConfigureAwait(false); + return RefreshIndex(InferenceImages); } public async Task RefreshIndex(IndexCollection indexCollection) @@ -72,16 +56,17 @@ public class ImageIndexService : IImageIndexService // Start var stopwatch = Stopwatch.StartNew(); - logger.LogInformation("Refreshing images index at {ImagesDir}...", imagesDir); + logger.LogInformation("Refreshing images index at {SearchDir}...", searchDir); var toAdd = new ConcurrentBag(); await Task.Run(() => { - var files = imagesDir.Info + var files = searchDir .EnumerateFiles("*.*", SearchOption.AllDirectories) - .Where(info => LocalImageFile.SupportedImageExtensions.Contains(info.Extension)) - .Select(info => new FilePath(info)); + .Where( + file => LocalImageFile.SupportedImageExtensions.Contains(file.Extension) + ); Parallel.ForEach( files, @@ -95,7 +80,7 @@ public class ImageIndexService : IImageIndexService var indexElapsed = stopwatch.Elapsed; - indexCollection.ItemsSource.EditDiff(toAdd, LocalImageFile.Comparer); + indexCollection.ItemsSource.EditDiff(toAdd); // End stopwatch.Stop(); @@ -120,110 +105,9 @@ public class ImageIndexService : IImageIndexService } } - /*public async Task RefreshIndex(IndexCollection indexCollection) - { - var imagesDir = settingsManager.ImagesDirectory; - if (!imagesDir.Exists) - { - return; - } - - // Start - var stopwatch = Stopwatch.StartNew(); - logger.LogInformation("Refreshing images index..."); - - using var db = await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false); - - var localImageFiles = db.GetCollection("LocalImageFiles")!; - - await localImageFiles.DeleteAllAsync().ConfigureAwait(false); - - // Record start of actual indexing - var indexStart = stopwatch.Elapsed; - - var added = 0; - - foreach ( - var file in imagesDir.Info - .EnumerateFiles("*.*", SearchOption.AllDirectories) - .Where(info => LocalImageFile.SupportedImageExtensions.Contains(info.Extension)) - .Select(info => new FilePath(info)) - ) - { - var relativePath = Path.GetRelativePath(imagesDir, file); - - // Skip if not in sub-path - if (!string.IsNullOrEmpty(subPath) && !relativePath.StartsWith(subPath)) - { - continue; - } - - // TODO: Support other types - const LocalImageFileType imageType = - LocalImageFileType.Inference | LocalImageFileType.TextToImage; - - // Get metadata - using var reader = new BinaryReader(new FileStream(file.FullPath, FileMode.Open)); - var metadata = ImageMetadata.ReadTextChunk(reader, "parameters-json"); - GenerationParameters? genParams = null; - - if (!string.IsNullOrWhiteSpace(metadata)) - { - genParams = JsonSerializer.Deserialize(metadata); - } - else - { - metadata = ImageMetadata.ReadTextChunk(reader, "parameters"); - if (!string.IsNullOrWhiteSpace(metadata)) - { - GenerationParameters.TryParse(metadata, out genParams); - } - } - - var localImage = new LocalImageFile - { - RelativePath = relativePath, - ImageType = imageType, - CreatedAt = file.Info.CreationTimeUtc, - LastModifiedAt = file.Info.LastWriteTimeUtc, - GenerationParameters = genParams - }; - - // Insert into database - await localImageFiles.InsertAsync(localImage).ConfigureAwait(false); - - added++; - } - // Record end of actual indexing - var indexEnd = stopwatch.Elapsed; - - await db.CommitAsync().ConfigureAwait(false); - - // End - stopwatch.Stop(); - var indexDuration = indexEnd - indexStart; - var dbDuration = stopwatch.Elapsed - indexDuration; - - logger.LogInformation( - "Image index updated for {Prefix} with {Entries} files, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)", - subPath, - added, - indexDuration.TotalMilliseconds, - dbDuration.TotalMilliseconds - ); - }*/ - /// public void BackgroundRefreshIndex() { RefreshIndexForAllCollections().SafeFireAndForget(); } - - /// - public async Task RemoveImage(LocalImageFile imageFile) - { - await liteDbContext.LocalImageFiles - .DeleteAsync(imageFile.RelativePath) - .ConfigureAwait(false); - } } diff --git a/StabilityMatrix.Core/Services/ModelIndexService.cs b/StabilityMatrix.Core/Services/ModelIndexService.cs index a70dff68..dec7eeb9 100644 --- a/StabilityMatrix.Core/Services/ModelIndexService.cs +++ b/StabilityMatrix.Core/Services/ModelIndexService.cs @@ -1,6 +1,7 @@ using System.Diagnostics; using AsyncAwaitBestPractices; using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Database; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; @@ -10,6 +11,7 @@ using StabilityMatrix.Core.Models.FileInterfaces; namespace StabilityMatrix.Core.Services; +[Singleton(typeof(IModelIndexService))] public class ModelIndexService : IModelIndexService { private readonly ILogger logger; diff --git a/StabilityMatrix.Core/Services/SettingsManager.cs b/StabilityMatrix.Core/Services/SettingsManager.cs index 30695720..47efd00c 100644 --- a/StabilityMatrix.Core/Services/SettingsManager.cs +++ b/StabilityMatrix.Core/Services/SettingsManager.cs @@ -7,6 +7,7 @@ using System.Text.Json.Serialization; using AsyncAwaitBestPractices; using NLog; using Refit; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.FileInterfaces; @@ -15,6 +16,7 @@ using StabilityMatrix.Core.Python; namespace StabilityMatrix.Core.Services; +[Singleton(typeof(ISettingsManager))] public class SettingsManager : ISettingsManager { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -69,6 +71,7 @@ public class SettingsManager : ISettingsManager public DirectoryPath ImagesDirectory => new(LibraryDir, "Images"); public DirectoryPath ImagesInferenceDirectory => ImagesDirectory.JoinDir("Inference"); + public DirectoryPath ConsolidatedImagesDirectory => ImagesDirectory.JoinDir("Consolidated"); public Settings Settings { get; private set; } = new(); @@ -163,7 +166,8 @@ public class SettingsManager : ISettingsManager T source, Expression> sourceProperty, Expression> settingsProperty, - bool setInitial = false + bool setInitial = false, + TimeSpan? delay = null ) where T : INotifyPropertyChanged { @@ -215,7 +219,14 @@ public class SettingsManager : ISettingsManager if (IsLibraryDirSet) { - SaveSettingsAsync().SafeFireAndForget(); + if (delay != null) + { + SaveSettingsDelayed(delay.Value).SafeFireAndForget(); + } + else + { + SaveSettingsAsync().SafeFireAndForget(); + } } else { @@ -654,4 +665,37 @@ public class SettingsManager : ISettingsManager { return Task.Run(SaveSettings); } + + private CancellationTokenSource? delayedSaveCts; + + private Task SaveSettingsDelayed(TimeSpan delay) + { + var cts = new CancellationTokenSource(); + + var oldCancellationToken = Interlocked.Exchange(ref delayedSaveCts, cts); + + try + { + oldCancellationToken?.Cancel(); + } + catch (ObjectDisposedException) { } + + return Task.Run( + async () => + { + try + { + await Task.Delay(delay, cts.Token); + + await SaveSettingsAsync(); + } + catch (TaskCanceledException) { } + finally + { + cts.Dispose(); + } + }, + CancellationToken.None + ); + } } diff --git a/StabilityMatrix.Core/StabilityMatrix.Core.csproj b/StabilityMatrix.Core/StabilityMatrix.Core.csproj index 1f95863e..70900a96 100644 --- a/StabilityMatrix.Core/StabilityMatrix.Core.csproj +++ b/StabilityMatrix.Core/StabilityMatrix.Core.csproj @@ -15,34 +15,36 @@ - + - + - - - - - + + + + + - - + + - + + + - + - - - - + + + + diff --git a/StabilityMatrix.Core/Updater/UpdateHelper.cs b/StabilityMatrix.Core/Updater/UpdateHelper.cs index 401736c1..325af086 100644 --- a/StabilityMatrix.Core/Updater/UpdateHelper.cs +++ b/StabilityMatrix.Core/Updater/UpdateHelper.cs @@ -2,6 +2,7 @@ using System.Text.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models.Configs; @@ -12,6 +13,7 @@ using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Updater; +[Singleton(typeof(IUpdateHelper))] public class UpdateHelper : IUpdateHelper { private readonly ILogger logger; diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs new file mode 100644 index 00000000..5905aca0 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs @@ -0,0 +1,28 @@ +using System.ComponentModel.DataAnnotations; +using StabilityMatrix.Avalonia.Models.Inference; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class FileNameFormatProviderTests +{ + [TestMethod] + public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow() + { + var provider = new FileNameFormatProvider(); + + var result = provider.Validate("{date}_{time}-{model_name}-{seed}"); + Assert.AreEqual(ValidationResult.Success, result); + } + + [TestMethod] + public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow() + { + var provider = new FileNameFormatProvider(); + + var result = provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}"); + Assert.AreNotEqual(ValidationResult.Success, result); + + Assert.AreEqual("Unknown variable 'invalid'", result.ErrorMessage); + } +} diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs new file mode 100644 index 00000000..0da1eb84 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs @@ -0,0 +1,24 @@ +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class FileNameFormatTests +{ + [TestMethod] + public void TestFileNameFormatParse() + { + var provider = new FileNameFormatProvider + { + GenerationParameters = new GenerationParameters { Seed = 123 }, + ProjectName = "uwu", + ProjectType = InferenceProjectType.TextToImage, + }; + + var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider); + + Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName()); + } +} diff --git a/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs new file mode 100644 index 00000000..9bb09eab --- /dev/null +++ b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs @@ -0,0 +1,61 @@ +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class PipInstallArgsTests +{ + [TestMethod] + public void TestGetTorch() + { + // Arrange + const string version = "==2.1.0"; + + // Act + var args = new PipInstallArgs().WithTorch(version).ToProcessArgs().ToString(); + + // Assert + Assert.AreEqual("torch==2.1.0", args); + } + + [TestMethod] + public void TestGetTorchWithExtraIndex() + { + // Arrange + const string version = ">=2.0.0"; + const string index = "cu118"; + + // Act + var args = new PipInstallArgs() + .WithTorch(version) + .WithTorchVision() + .WithTorchExtraIndex(index) + .ToProcessArgs() + .ToString(); + + // Assert + Assert.AreEqual( + "torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118", + args + ); + } + + [TestMethod] + public void TestGetTorchWithMoreStuff() + { + // Act + var args = new PipInstallArgs() + .AddArg("--pre") + .WithTorch("~=2.0.0") + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu") + .ToString(); + + // Assert + Assert.AreEqual( + "--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu", + args + ); + } +} diff --git a/StabilityMatrix.Tests/Models/GenerationParametersTests.cs b/StabilityMatrix.Tests/Models/GenerationParametersTests.cs new file mode 100644 index 00000000..d22caf8c --- /dev/null +++ b/StabilityMatrix.Tests/Models/GenerationParametersTests.cs @@ -0,0 +1,70 @@ +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Tests.Models; + +[TestClass] +public class GenerationParametersTests +{ + [TestMethod] + public void TestParse() + { + const string data = """ + test123 + Negative prompt: test, easy negative + Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3589107295, Size: 1024x1028, Model hash: 9aa0c3e54d, Model: nightvisionXL_v0770_BakedVAE, VAE hash: 235745af8d, VAE: sdxl_vae.safetensors, Style Selector Enabled: True, Style Selector Randomize: False, Style Selector Style: base, Version: 1.6.0 + """; + + Assert.IsTrue(GenerationParameters.TryParse(data, out var result)); + + Assert.AreEqual("test123", result.PositivePrompt); + Assert.AreEqual("test, easy negative", result.NegativePrompt); + Assert.AreEqual(20, result.Steps); + Assert.AreEqual("Euler a", result.Sampler); + Assert.AreEqual(7, result.CfgScale); + Assert.AreEqual(3589107295, result.Seed); + Assert.AreEqual(1024, result.Width); + Assert.AreEqual(1028, result.Height); + Assert.AreEqual("9aa0c3e54d", result.ModelHash); + Assert.AreEqual("nightvisionXL_v0770_BakedVAE", result.ModelName); + } + + [TestMethod] + public void TestParse_NoNegative() + { + const string data = """ + test123 + Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3589107295, Size: 1024x1028, Model hash: 9aa0c3e54d, Model: nightvisionXL_v0770_BakedVAE, VAE hash: 235745af8d, VAE: sdxl_vae.safetensors, Style Selector Enabled: True, Style Selector Randomize: False, Style Selector Style: base, Version: 1.6.0 + """; + + Assert.IsTrue(GenerationParameters.TryParse(data, out var result)); + + Assert.AreEqual("test123", result.PositivePrompt); + Assert.IsNull(result.NegativePrompt); + Assert.AreEqual(20, result.Steps); + Assert.AreEqual("Euler a", result.Sampler); + Assert.AreEqual(7, result.CfgScale); + Assert.AreEqual(3589107295, result.Seed); + Assert.AreEqual(1024, result.Width); + Assert.AreEqual(1028, result.Height); + Assert.AreEqual("9aa0c3e54d", result.ModelHash); + Assert.AreEqual("nightvisionXL_v0770_BakedVAE", result.ModelName); + } + + [TestMethod] + public void TestParseLineFields() + { + const string lastLine = + @"Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1"; + + var fields = GenerationParameters.ParseLine(lastLine); + + Assert.AreEqual(7, fields.Count); + Assert.AreEqual("30", fields["Steps"]); + Assert.AreEqual("DPM++ 2M Karras", fields["Sampler"]); + Assert.AreEqual("7", fields["CFG scale"]); + Assert.AreEqual("2216407431", fields["Seed"]); + Assert.AreEqual("640x896", fields["Size"]); + Assert.AreEqual("eb2h052f91", fields["Model hash"]); + Assert.AreEqual("anime_v1", fields["Model"]); + } +} diff --git a/StabilityMatrix.Tests/Models/ProcessArgsTests.cs b/StabilityMatrix.Tests/Models/ProcessArgsTests.cs new file mode 100644 index 00000000..ae3281d9 --- /dev/null +++ b/StabilityMatrix.Tests/Models/ProcessArgsTests.cs @@ -0,0 +1,43 @@ +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Tests.Models; + +[TestClass] +public class ProcessArgsTests +{ + [DataTestMethod] + [DataRow("pip", new[] { "pip" })] + [DataRow("pip install torch", new[] { "pip", "install", "torch" })] + [DataRow( + "pip install -r \"file spaces/here\"", + new[] { "pip", "install", "-r", "file spaces/here" } + )] + [DataRow( + "pip install -r \"file spaces\\here\"", + new[] { "pip", "install", "-r", "file spaces\\here" } + )] + public void TestStringToArray(string input, string[] expected) + { + ProcessArgs args = input; + string[] result = args; + CollectionAssert.AreEqual(expected, result); + } + + [DataTestMethod] + [DataRow(new[] { "pip" }, "pip")] + [DataRow(new[] { "pip", "install", "torch" }, "pip install torch")] + [DataRow( + new[] { "pip", "install", "-r", "file spaces/here" }, + "pip install -r \"file spaces/here\"" + )] + [DataRow( + new[] { "pip", "install", "-r", "file spaces\\here" }, + "pip install -r \"file spaces\\here\"" + )] + public void TestArrayToString(string[] input, string expected) + { + ProcessArgs args = input; + string result = args; + Assert.AreEqual(expected, result); + } +} diff --git a/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj b/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj index e82a92a8..f604bbea 100644 --- a/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj +++ b/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj @@ -11,18 +11,18 @@ - + - + - + all runtime; build; native; contentfiles; analyzers; buildtransitive - +