diff --git a/CHANGELOG.md b/CHANGELOG.md
index f69ee68d..bce62ccd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,10 +5,42 @@ All notable changes to Stability Matrix will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2.0.0.html).
+## v2.6.0
+### Added
+- Added **Output Sharing** option for all packages in the three-dots menu on the Packages page
+ - This will link the package's output folders to the relevant subfolders in the "Outputs" directory
+ - When a package only has a generic "outputs" folder, all generated images from that package will be linked to the "Outputs\Text2Img" folder when this option is enabled
+- Added **Outputs page** for viewing generated images from any package, or the shared output folder
+- Added [Stable Diffusion WebUI/UX](https://github.com/anapnoe/stable-diffusion-webui-ux) package
+- Added [Stable Diffusion WebUI-DirectML](https://github.com/lshqqytiger/stable-diffusion-webui-directml) package
+- Added [kohya_ss](https://github.com/bmaltais/kohya_ss) package
+- Added [Fooocus-ControlNet-SDXL](https://github.com/fenneishi/Fooocus-ControlNet-SDXL) package
+- Added GPU compatibility badges to the installers
+- Added filtering of "incompatible" packages (ones that do not support your GPU) to all installers
+ - This can be overridden by checking the new "Show All Packages" checkbox
+- Added more launch options for Fooocus, such as the `--preset` option
+- Added Ctrl+ScrollWheel to change image size in the inference output gallery and new Outputs page
+- Added "No Images Found" placeholder for non-connected models on the Checkpoints tab
+### Changed
+- If ComfyUI for Inference is chosen during the One-Click Installer, the Inference page will be opened after installation instead of the Launch page
+- Changed all package installs & updates to use git commands instead of downloading zip files
+- The One-Click Installer now uses the new progress dialog with console
+- NVIDIA GPU users will be updated to use CUDA 12.1 for ComfyUI & Fooocus packages for a slight performance improvement
+ - Update will occur the next time the package is updated, or on a fresh install
+ - Note: CUDA 12.1 is only available on Maxwell (GTX 900 series) and newer GPUs
+- Improved Model Browser download stability with automatic retries for download errors
+- Optimized page navigation and syntax formatting configurations to improve startup time
+### Fixed
+- Fixed crash when clicking Inference gallery image after the image is deleted externally in file explorer
+- Fixed Inference popup Install button not working on One-Click Installer
+- Fixed Inference Prompt Completion window sometimes not showing while typing
+- Fixed "Show Model Images" toggle on Checkpoints page sometimes displaying cut-off model images
+- Fixed missing httpx package during Automatic1111 install
+- Fixed some instances of localized text being cut off from controls being too small
+
## v2.5.7
### Fixed
-- Fixed error `got an unexpected keyword argument 'socket_options'` on fresh installs of Automatic1111 Stable Diffusion WebUI
-due to missing httpx dependency specification from gradio
+- Fixed error `got an unexpected keyword argument 'socket_options'` on fresh installs of Automatic1111 Stable Diffusion WebUI due to missing httpx dependency specification from gradio
## v2.5.6
### Added
diff --git a/Jenkinsfile b/Jenkinsfile
index 7b493b1a..51dddff3 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -35,11 +35,6 @@ node("Diligence") {
stage('Publish Linux') {
sh "/home/jenkins/.dotnet/tools/pupnet --runtime linux-x64 --kind appimage --app-version ${version} --clean -y"
}
-
- stage ('Archive Artifacts') {
- archiveArtifacts artifacts: 'out/*.exe', followSymlinks: false
- archiveArtifacts artifacts: 'Release/linux-x64/*.AppImage', followSymlinks: false
- }
}
} finally {
stage('Cleanup') {
diff --git a/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj b/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj
index cfbcd729..75091167 100644
--- a/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj
+++ b/StabilityMatrix.Avalonia.Diagnostics/StabilityMatrix.Avalonia.Diagnostics.csproj
@@ -19,12 +19,12 @@
-
-
+
+
-
-
+
+
diff --git a/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs b/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs
index fef4f8db..66a66152 100644
--- a/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs
+++ b/StabilityMatrix.Avalonia/Animations/BetterEntranceNavigationTransition.cs
@@ -1,12 +1,10 @@
using System;
using System.Threading;
-using AsyncAwaitBestPractices;
using Avalonia;
using Avalonia.Animation;
using Avalonia.Animation.Easings;
using Avalonia.Media;
using Avalonia.Styling;
-using FluentAvalonia.UI.Media.Animation;
namespace StabilityMatrix.Avalonia.Animations;
@@ -23,7 +21,7 @@ public class BetterEntranceNavigationTransition : BaseTransitionInfo
/// Gets or sets the Vertical Offset used when animating
///
public double FromVerticalOffset { get; set; } = 100;
-
+
public override async void RunAnimation(Animatable ctrl, CancellationToken cancellationToken)
{
var animation = new Animation
@@ -36,7 +34,7 @@ public class BetterEntranceNavigationTransition : BaseTransitionInfo
Setters =
{
new Setter(Visual.OpacityProperty, 0.0),
- new Setter(TranslateTransform.XProperty,FromHorizontalOffset),
+ new Setter(TranslateTransform.XProperty, FromHorizontalOffset),
new Setter(TranslateTransform.YProperty, FromVerticalOffset)
},
Cue = new Cue(0d)
@@ -46,7 +44,7 @@ public class BetterEntranceNavigationTransition : BaseTransitionInfo
Setters =
{
new Setter(Visual.OpacityProperty, 1d),
- new Setter(TranslateTransform.XProperty,0.0),
+ new Setter(TranslateTransform.XProperty, 0.0),
new Setter(TranslateTransform.YProperty, 0.0)
},
Cue = new Cue(1d)
diff --git a/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs b/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs
index 507d9872..7014e63f 100644
--- a/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs
+++ b/StabilityMatrix.Avalonia/Animations/ItemsRepeaterArrangeAnimation.cs
@@ -2,7 +2,6 @@
using Avalonia;
using Avalonia.Controls;
using Avalonia.Rendering.Composition;
-using Avalonia.Rendering.Composition.Animations;
namespace StabilityMatrix.Avalonia.Animations;
diff --git a/StabilityMatrix.Avalonia/App.axaml b/StabilityMatrix.Avalonia/App.axaml
index 0258b36b..f33a4d5c 100644
--- a/StabilityMatrix.Avalonia/App.axaml
+++ b/StabilityMatrix.Avalonia/App.axaml
@@ -20,6 +20,8 @@
+
+
@@ -56,10 +58,10 @@
-
+
-->
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+
-
+ FontSize="12"
+ Text="{Binding FileNameWithoutExtension}"
+ TextAlignment="Center"
+ TextTrimming="CharacterEllipsis" />
@@ -206,10 +229,10 @@
-
+
-
-
+
+
diff --git a/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs
index b36e1fe1..08a9cca5 100644
--- a/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/ImageFolderCard.axaml.cs
@@ -1,9 +1,23 @@
-using Avalonia.Input;
+using Avalonia.Controls;
+using Avalonia.Controls.Primitives;
+using Avalonia.Input;
+using StabilityMatrix.Avalonia.ViewModels.Inference;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Models.Settings;
namespace StabilityMatrix.Avalonia.Controls;
+[Transient]
public class ImageFolderCard : DropTargetTemplatedControlBase
{
+ private ItemsRepeater? imageRepeater;
+
+ protected override void OnApplyTemplate(TemplateAppliedEventArgs e)
+ {
+ imageRepeater = e.NameScope.Find("ImageRepeater");
+ base.OnApplyTemplate(e);
+ }
+
///
protected override void DropHandler(object? sender, DragEventArgs e)
{
@@ -17,4 +31,30 @@ public class ImageFolderCard : DropTargetTemplatedControlBase
base.DragOverHandler(sender, e);
e.Handled = true;
}
+
+ protected override void OnPointerWheelChanged(PointerWheelEventArgs e)
+ {
+ if (e.KeyModifiers != KeyModifiers.Control)
+ return;
+ if (DataContext is not ImageFolderCardViewModel vm)
+ return;
+
+ if (e.Delta.Y > 0)
+ {
+ if (vm.ImageSize.Height >= 500)
+ return;
+ vm.ImageSize += new Size(15, 19);
+ }
+ else
+ {
+ if (vm.ImageSize.Height <= 200)
+ return;
+ vm.ImageSize -= new Size(15, 19);
+ }
+
+ imageRepeater?.InvalidateArrange();
+ imageRepeater?.InvalidateMeasure();
+
+ e.Handled = true;
+ }
}
diff --git a/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml b/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml
index 3fe4071e..6fb6e844 100644
--- a/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml
+++ b/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml
@@ -13,11 +13,6 @@
-
-
-
+
+
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+ IsVisible="{Binding IsRefinerSelectionEnabled}"
+ Text="{x:Static lang:Resources.Label_Refiner}"
+ TextAlignment="Left" />
+
-
-
+
+
-
+ IsVisible="{Binding IsVaeSelectionEnabled}"
+ Text="{x:Static lang:Resources.Label_VAE}"
+ TextAlignment="Left" />
+
-
+
diff --git a/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs
index e997a874..622f1cbe 100644
--- a/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/ModelCard.axaml.cs
@@ -1,9 +1,7 @@
-using Avalonia;
-using Avalonia.Controls;
-using Avalonia.Controls.Primitives;
+using Avalonia.Controls.Primitives;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
-public class ModelCard : TemplatedControl
-{
-}
\ No newline at end of file
+[Transient]
+public class ModelCard : TemplatedControl { }
diff --git a/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs b/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs
index 9342ec12..ef1f1114 100644
--- a/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/Paginator.axaml.cs
@@ -2,7 +2,6 @@
using System.Windows.Input;
using Avalonia;
using Avalonia.Controls.Primitives;
-using AvaloniaEdit.Utils;
using CommunityToolkit.Mvvm.Input;
namespace StabilityMatrix.Avalonia.Controls;
diff --git a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs
index 47a02af1..f5a79dcb 100644
--- a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs
@@ -6,9 +6,12 @@ using AvaloniaEdit;
using AvaloniaEdit.Editing;
using AvaloniaEdit.Utils;
using StabilityMatrix.Avalonia.Helpers;
+using StabilityMatrix.Avalonia.Models;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
+[Transient]
public class PromptCard : TemplatedControl
{
///
@@ -31,7 +34,7 @@ public class PromptCard : TemplatedControl
{
if (editor is not null)
{
- TextEditorConfigs.ConfigForPrompt(editor);
+ TextEditorConfigs.Configure(editor, TextEditorPreset.Prompt);
editor.TextArea.Margin = new Thickness(0, 0, 4, 0);
if (editor.TextArea.ActiveInputHandler is TextAreaInputHandler inputHandler)
diff --git a/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs b/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs
index f5620f12..1943c3ae 100644
--- a/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/RefreshBadge.axaml.cs
@@ -1,7 +1,9 @@
using Avalonia.Markup.Xaml;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
+[Transient]
public partial class RefreshBadge : UserControlBase
{
public RefreshBadge()
diff --git a/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs
index 7e732e0f..ace972b1 100644
--- a/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/SamplerCard.axaml.cs
@@ -1,7 +1,7 @@
using Avalonia.Controls.Primitives;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
-public class SamplerCard : TemplatedControl
-{
-}
+[Transient]
+public class SamplerCard : TemplatedControl { }
diff --git a/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollContentPresenter.cs b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollContentPresenter.cs
new file mode 100644
index 00000000..a5f30ff1
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollContentPresenter.cs
@@ -0,0 +1,14 @@
+using Avalonia.Controls.Presenters;
+using Avalonia.Input;
+
+namespace StabilityMatrix.Avalonia.Controls.Scroll;
+
+public class BetterScrollContentPresenter : ScrollContentPresenter
+{
+ protected override void OnPointerWheelChanged(PointerWheelEventArgs e)
+ {
+ if (e.KeyModifiers == KeyModifiers.Control)
+ return;
+ base.OnPointerWheelChanged(e);
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.axaml b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.axaml
new file mode 100644
index 00000000..48f59f05
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.axaml
@@ -0,0 +1,69 @@
+
+
+
+
+ Item 1
+ Item 2
+ Item 3
+ Item 4
+ Item 5
+ Item 6
+ Item 7
+ Item 8
+ Item 9
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.cs b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.cs
new file mode 100644
index 00000000..05adc971
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Controls/Scroll/BetterScrollViewer.cs
@@ -0,0 +1,5 @@
+using Avalonia.Controls;
+
+namespace StabilityMatrix.Avalonia.Controls.Scroll;
+
+public class BetterScrollViewer : ScrollViewer { }
diff --git a/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs
index f08853dd..06d164e9 100644
--- a/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/SeedCard.axaml.cs
@@ -1,7 +1,7 @@
using Avalonia.Controls.Primitives;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
-public class SeedCard : TemplatedControl
-{
-}
+[Transient]
+public class SeedCard : TemplatedControl { }
diff --git a/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs
index dd748696..726cf414 100644
--- a/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/SelectImageCard.axaml.cs
@@ -1,3 +1,6 @@
-namespace StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
+namespace StabilityMatrix.Avalonia.Controls;
+
+[Transient]
public class SelectImageCard : DropTargetTemplatedControlBase { }
diff --git a/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml
new file mode 100644
index 00000000..c94276b1
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml
@@ -0,0 +1,51 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.cs b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.cs
new file mode 100644
index 00000000..7878bf8b
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.cs
@@ -0,0 +1,56 @@
+using System;
+using AsyncImageLoader;
+using Avalonia;
+using Avalonia.Controls;
+using Avalonia.Controls.Primitives;
+
+namespace StabilityMatrix.Avalonia.Controls.SelectableImageCard;
+
+public class SelectableImageButton : Button
+{
+ public static readonly StyledProperty IsSelectedProperty =
+ ToggleButton.IsCheckedProperty.AddOwner();
+
+ public static readonly StyledProperty SourceProperty =
+ AdvancedImage.SourceProperty.AddOwner();
+
+ public static readonly StyledProperty ImageWidthProperty = AvaloniaProperty.Register<
+ SelectableImageButton,
+ double
+ >("ImageWidth", 300);
+
+ public static readonly StyledProperty ImageHeightProperty = AvaloniaProperty.Register<
+ SelectableImageButton,
+ double
+ >("ImageHeight", 300);
+
+ static SelectableImageButton()
+ {
+ AffectsRender(ImageWidthProperty, ImageHeightProperty);
+ AffectsArrange(ImageWidthProperty, ImageHeightProperty);
+ }
+
+ public double ImageHeight
+ {
+ get => GetValue(ImageHeightProperty);
+ set => SetValue(ImageHeightProperty, value);
+ }
+
+ public double ImageWidth
+ {
+ get => GetValue(ImageWidthProperty);
+ set => SetValue(ImageWidthProperty, value);
+ }
+
+ public bool? IsSelected
+ {
+ get => GetValue(IsSelectedProperty);
+ set => SetValue(IsSelectedProperty, value);
+ }
+
+ public string? Source
+ {
+ get => GetValue(SourceProperty);
+ set => SetValue(SourceProperty, value);
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs
index 6fd22370..4c2a250d 100644
--- a/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/SharpenCard.axaml.cs
@@ -1,7 +1,7 @@
-using Avalonia;
-using Avalonia.Controls;
-using Avalonia.Controls.Primitives;
+using Avalonia.Controls.Primitives;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
+[Transient]
public class SharpenCard : TemplatedControl { }
diff --git a/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs
index c8374349..2341a96d 100644
--- a/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/StackCard.axaml.cs
@@ -1,14 +1,18 @@
using System.Diagnostics.CodeAnalysis;
using Avalonia;
using Avalonia.Controls.Primitives;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+[Transient]
public class StackCard : TemplatedControl
{
- public static readonly StyledProperty SpacingProperty = AvaloniaProperty.Register(
- "Spacing", 8);
+ public static readonly StyledProperty SpacingProperty = AvaloniaProperty.Register<
+ StackCard,
+ int
+ >("Spacing", 8);
public int Spacing
{
diff --git a/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs b/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs
index 3fb2304b..38208891 100644
--- a/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/StackExpander.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia;
using Avalonia.Controls.Primitives;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
+[Transient]
public class StackExpander : TemplatedControl
{
public static readonly StyledProperty SpacingProperty = AvaloniaProperty.Register<
diff --git a/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs
index 083d09fa..7e2e2327 100644
--- a/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs
+++ b/StabilityMatrix.Avalonia/Controls/UpscalerCard.axaml.cs
@@ -1,13 +1,14 @@
-using System;
-using AsyncAwaitBestPractices;
+using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.ViewModels.Inference;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Avalonia.Controls;
+[Transient]
public class UpscalerCard : TemplatedControl
{
///
diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs
index 43d8e751..4d785605 100644
--- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs
+++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs
@@ -1,14 +1,16 @@
using System;
using System.Collections.Generic;
-using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Http;
using System.Text;
using AvaloniaEdit.Utils;
+using DynamicData;
+using DynamicData.Binding;
using Microsoft.Extensions.DependencyInjection;
-using StabilityMatrix.Avalonia.Controls;
+using NSubstitute;
+using NSubstitute.ReturnsExtensions;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
@@ -20,6 +22,7 @@ using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.Progress;
using StabilityMatrix.Avalonia.ViewModels.Inference;
+using StabilityMatrix.Avalonia.ViewModels.OutputsPage;
using StabilityMatrix.Avalonia.ViewModels.Settings;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Database;
@@ -29,7 +32,9 @@ using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Api.Comfy;
+using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.PackageModification;
+using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
@@ -52,7 +57,7 @@ public static class DesignData
if (isInitialized)
throw new InvalidOperationException("DesignData is already initialized.");
- var services = new ServiceCollection();
+ var services = App.ConfigureServices();
var activePackageId = Guid.NewGuid();
services.AddSingleton(
@@ -106,18 +111,18 @@ public static class DesignData
// Mock services
services
- .AddSingleton()
- .AddSingleton()
- .AddSingleton()
- .AddSingleton()
- .AddSingleton()
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
+ .AddSingleton(Substitute.For())
.AddSingleton()
- .AddSingleton()
.AddSingleton()
.AddSingleton()
- .AddSingleton()
- .AddSingleton()
- .AddSingleton();
+ .AddSingleton();
// Placeholder services that nobody should need during design time
services
@@ -128,12 +133,6 @@ public static class DesignData
.AddSingleton(_ => null!)
.AddSingleton(_ => null!);
- // Using some default service implementations from App
- App.ConfigurePackages(services);
- App.ConfigurePageViewModels(services);
- App.ConfigureDialogViewModels(services);
- App.ConfigureViews(services);
-
// Override Launch page with mock
services.Remove(ServiceDescriptor.Singleton());
services.AddSingleton();
@@ -172,12 +171,61 @@ public static class DesignData
LaunchOptionsViewModel.UpdateFilterCards();
InstallerViewModel = Services.GetRequiredService();
- InstallerViewModel.AvailablePackages = packageFactory
- .GetAllAvailablePackages()
- .ToImmutableArray();
+ InstallerViewModel.AvailablePackages = new ObservableCollectionExtended(
+ packageFactory.GetAllAvailablePackages()
+ );
InstallerViewModel.SelectedPackage = InstallerViewModel.AvailablePackages[0];
InstallerViewModel.ReleaseNotes = "## Release Notes\nThis is a test release note.";
+ ObservableCacheEx.AddOrUpdate(
+ CheckpointsPageViewModel.CheckpointFoldersCache,
+ new CheckpointFolder[]
+ {
+ new(settingsManager, downloadService, modelFinder, notificationService)
+ {
+ DirectoryPath = "Models/StableDiffusion",
+ DisplayedCheckpointFiles = new ObservableCollectionExtended()
+ {
+ new()
+ {
+ FilePath = "~/Models/StableDiffusion/electricity-light.safetensors",
+ Title = "Auroral Background",
+ PreviewImagePath =
+ "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/"
+ + "78fd2a0a-42b6-42b0-9815-81cb11bb3d05/00009-2423234823.jpeg",
+ ConnectedModel = new ConnectedModelInfo
+ {
+ VersionName = "Lightning Auroral",
+ BaseModel = "SD 1.5",
+ ModelName = "Auroral Background",
+ ModelType = CivitModelType.Model,
+ FileMetadata = new CivitFileMetadata
+ {
+ Format = CivitModelFormat.SafeTensor,
+ Fp = CivitModelFpType.fp16,
+ Size = CivitModelSize.pruned,
+ }
+ }
+ },
+ new()
+ {
+ FilePath = "~/Models/Lora/model.safetensors",
+ Title = "Some model"
+ },
+ },
+ },
+ new(settingsManager, downloadService, modelFinder, notificationService)
+ {
+ Title = "Lora",
+ DirectoryPath = "Packages/Lora",
+ DisplayedCheckpointFiles = new ObservableCollectionExtended
+ {
+ new() { FilePath = "~/Models/Lora/lora_v2.pt", Title = "Best Lora v2", }
+ }
+ }
+ }
+ );
+
/*// Checkpoints page
CheckpointsPageViewModel.CheckpointFolders =
new CheckpointFolder[]
@@ -336,6 +384,26 @@ public static class DesignData
public static LaunchPageViewModel LaunchPageViewModel =>
Services.GetRequiredService();
+ public static OutputsPageViewModel OutputsPageViewModel
+ {
+ get
+ {
+ var vm = Services.GetRequiredService();
+ vm.Outputs = new ObservableCollectionExtended
+ {
+ new(
+ new LocalImageFile
+ {
+ AbsolutePath =
+ "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/78fd2a0a-42b6-42b0-9815-81cb11bb3d05/00009-2423234823.jpeg",
+ ImageType = LocalImageFileType.TextToImage
+ }
+ )
+ };
+ return vm;
+ }
+ }
+
public static PackageManagerViewModel PackageManagerViewModel
{
get
@@ -465,6 +533,15 @@ The gallery images are often inpainted, but you will get something very similar
viewModel.EnvVars = new ObservableCollection { new("UWU", "TRUE"), };
});
+ public static PythonPackagesViewModel PythonPackagesViewModel =>
+ DialogFactory.Get(vm =>
+ {
+ vm.AddPackages(
+ new PipPackageInfo("pip", "1.0.0"),
+ new PipPackageInfo("torch", "2.1.0+cu121")
+ );
+ });
+
public static InferenceTextToImageViewModel InferenceTextToImageViewModel =>
DialogFactory.Get(vm =>
{
@@ -603,8 +680,8 @@ The gallery images are often inpainted, but you will get something very similar
get
{
var list = new CompletionList { IsFiltering = true };
- list.CompletionData.AddRange(SampleCompletionData);
- list.FilteredCompletionData.AddRange(list.CompletionData);
+ ExtensionMethods.AddRange(list.CompletionData, SampleCompletionData);
+ ExtensionMethods.AddRange(list.FilteredCompletionData, list.CompletionData);
list.SelectItem("te", true);
return list;
}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs b/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs
deleted file mode 100644
index 6bffe945..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs
+++ /dev/null
@@ -1,12 +0,0 @@
-using System;
-using StabilityMatrix.Core.Api;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockApiFactory : IApiFactory
-{
- public T CreateRefitClient(Uri baseAddress)
- {
- throw new NotImplementedException();
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs b/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs
deleted file mode 100644
index 79851cd7..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs
+++ /dev/null
@@ -1,18 +0,0 @@
-using System;
-using StabilityMatrix.Avalonia.Services;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockDiscordRichPresenceService : IDiscordRichPresenceService
-{
- ///
- public void Dispose()
- {
- GC.SuppressFinalize(this);
- }
-
- ///
- public void UpdateState()
- {
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs
deleted file mode 100644
index b1e08d8b..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs
+++ /dev/null
@@ -1,50 +0,0 @@
-using System;
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-using StabilityMatrix.Core.Models.Progress;
-using StabilityMatrix.Core.Services;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockDownloadService : IDownloadService
-{
- public Task DownloadToFileAsync(
- string downloadUrl,
- string downloadPath,
- IProgress? progress = null,
- string? httpClientName = null,
- CancellationToken cancellationToken = default
- )
- {
- return Task.CompletedTask;
- }
-
- ///
- public Task ResumeDownloadToFileAsync(
- string downloadUrl,
- string downloadPath,
- long existingFileSize,
- IProgress? progress = null,
- string? httpClientName = null,
- CancellationToken cancellationToken = default
- )
- {
- return Task.CompletedTask;
- }
-
- ///
- public Task GetFileSizeAsync(
- string downloadUrl,
- string? httpClientName = null,
- CancellationToken cancellationToken = default
- )
- {
- return Task.FromResult(0L);
- }
-
- public Task GetImageStreamFromUrl(string url)
- {
- return Task.FromResult(new MemoryStream(new byte[24]) as Stream)!;
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs b/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs
deleted file mode 100644
index 452af8c4..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs
+++ /dev/null
@@ -1,12 +0,0 @@
-using System;
-using System.Net.Http;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockHttpClientFactory : IHttpClientFactory
-{
- public HttpClient CreateClient(string name)
- {
- throw new NotImplementedException();
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs b/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs
index d5ee1650..39474975 100644
--- a/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs
+++ b/StabilityMatrix.Avalonia/DesignData/MockImageIndexService.cs
@@ -1,8 +1,8 @@
-using System.Collections.Generic;
using System.Threading.Tasks;
+using DynamicData;
+using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
-using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.DesignData;
@@ -10,47 +10,50 @@ namespace StabilityMatrix.Avalonia.DesignData;
public class MockImageIndexService : IImageIndexService
{
///
- public IndexCollection InferenceImages { get; } =
- new IndexCollection(null!, file => file.RelativePath)
+ public IndexCollection InferenceImages { get; }
+
+ public MockImageIndexService()
+ {
+ InferenceImages = new IndexCollection(
+ this,
+ file => file.AbsolutePath
+ )
{
- RelativePath = "inference"
+ RelativePath = "Inference"
};
+ }
///
- public Task> GetLocalImagesByPrefix(string pathPrefix)
+ public Task RefreshIndexForAllCollections()
{
- return Task.FromResult(
- (IReadOnlyList)
- new LocalImageFile[]
- {
- new()
- {
- RelativePath =
- "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/4a7e00a7-6f18-42d4-87c0-10e792df2640/width=1152",
- },
- new()
- {
- RelativePath =
- "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/a318ac1f-3ad0-48ac-98cc-79126febcc17/width=1024",
- },
- new()
- {
- RelativePath =
- "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/16588c94-6595-4be9-8806-d7e6e22d198c/width=1152",
- }
- }
- );
+ return RefreshIndex(InferenceImages);
}
- ///
- public Task RefreshIndexForAllCollections()
+ private static LocalImageFile GetSampleImage(string url)
{
- return Task.CompletedTask;
+ return new LocalImageFile
+ {
+ AbsolutePath = url,
+ GenerationParameters = GenerationParameters.GetSample(),
+ ImageSize = new System.Drawing.Size(1024, 1024)
+ };
}
///
public Task RefreshIndex(IndexCollection indexCollection)
{
+ var toAdd = new[]
+ {
+ GetSampleImage(
+ "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/a318ac1f-3ad0-48ac-98cc-79126febcc17/width=1024"
+ ),
+ GetSampleImage(
+ "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/16588c94-6595-4be9-8806-d7e6e22d198c/width=1152"
+ )
+ };
+
+ indexCollection.ItemsSource.EditDiff(toAdd);
+
return Task.CompletedTask;
}
@@ -59,10 +62,4 @@ public class MockImageIndexService : IImageIndexService
{
throw new System.NotImplementedException();
}
-
- ///
- public Task RemoveImage(LocalImageFile imageFile)
- {
- throw new System.NotImplementedException();
- }
}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs b/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
index dffaee55..70b928f4 100644
--- a/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
+++ b/StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
@@ -3,6 +3,7 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
+using DynamicData;
using DynamicData.Binding;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
@@ -48,6 +49,17 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
///
public bool CanUserDisconnect => IsConnected && !IsConnecting;
+ public MockInferenceClientManager()
+ {
+ Models.AddRange(
+ new[]
+ {
+ HybridModelFile.FromRemote("v1-5-pruned-emaonly.safetensors"),
+ HybridModelFile.FromRemote("artshaper1.safetensors"),
+ }
+ );
+ }
+
///
public Task CopyImageToInputAsync(
FilePath imageFile,
diff --git a/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs b/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs
deleted file mode 100644
index a8f2996e..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs
+++ /dev/null
@@ -1,62 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Threading.Tasks;
-using LiteDB.Async;
-using StabilityMatrix.Core.Database;
-using StabilityMatrix.Core.Models.Api;
-using StabilityMatrix.Core.Models.Database;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockLiteDbContext : ILiteDbContext
-{
- public LiteDatabaseAsync Database => throw new NotImplementedException();
- public ILiteCollectionAsync CivitModels => throw new NotImplementedException();
- public ILiteCollectionAsync CivitModelVersions =>
- throw new NotImplementedException();
- public ILiteCollectionAsync CivitModelQueryCache =>
- throw new NotImplementedException();
- public ILiteCollectionAsync LocalModelFiles =>
- throw new NotImplementedException();
- public ILiteCollectionAsync InferenceProjects =>
- throw new NotImplementedException();
- public ILiteCollectionAsync LocalImageFiles =>
- throw new NotImplementedException();
-
- public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(
- string hashBlake3
- )
- {
- return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null));
- }
-
- public Task UpsertCivitModelAsync(CivitModel civitModel)
- {
- return Task.FromResult(true);
- }
-
- public Task UpsertCivitModelAsync(IEnumerable civitModels)
- {
- return Task.FromResult(true);
- }
-
- public Task UpsertCivitModelQueryCacheEntryAsync(CivitModelQueryCacheEntry entry)
- {
- return Task.FromResult(true);
- }
-
- public Task GetGithubCacheEntry(string cacheKey)
- {
- return Task.FromResult(null);
- }
-
- public Task UpsertGithubCacheEntry(GithubCacheEntry cacheEntry)
- {
- return Task.FromResult(true);
- }
-
- public void Dispose()
- {
- GC.SuppressFinalize(this);
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs b/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs
deleted file mode 100644
index 4fdde7ad..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs
+++ /dev/null
@@ -1,47 +0,0 @@
-using System;
-using System.Threading.Tasks;
-using Avalonia;
-using Avalonia.Controls.Notifications;
-using StabilityMatrix.Avalonia.Services;
-using StabilityMatrix.Core.Models;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockNotificationService : INotificationService
-{
- public void Initialize(Visual? visual,
- NotificationPosition position = NotificationPosition.BottomRight, int maxItems = 3)
- {
- }
-
- public void Show(INotification notification)
- {
- }
-
- public Task> TryAsync(Task task, string title = "Error", string? message = null,
- NotificationType appearance = NotificationType.Error)
- {
- return Task.FromResult(new TaskResult(default!));
- }
-
- public Task> TryAsync(Task task, string title = "Error", string? message = null,
- NotificationType appearance = NotificationType.Error)
- {
- return Task.FromResult(new TaskResult(true));
- }
-
- public void Show(
- string title,
- string message,
- NotificationType appearance = NotificationType.Information,
- TimeSpan? expiration = null)
- {
- }
-
- public void ShowPersistent(
- string title,
- string message,
- NotificationType appearance = NotificationType.Information)
- {
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs b/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs
deleted file mode 100644
index 35c21160..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs
+++ /dev/null
@@ -1,26 +0,0 @@
-using System.Threading.Tasks;
-using StabilityMatrix.Core.Helper;
-using StabilityMatrix.Core.Models.FileInterfaces;
-using StabilityMatrix.Core.Models.Packages;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockSharedFolders : ISharedFolders
-{
- public void SetupLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory)
- {
- }
-
- public Task UpdateLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory)
- {
- return Task.CompletedTask;
- }
-
- public void RemoveLinksForAllPackages()
- {
- }
-
- public void SetupSharedModelFolders()
- {
- }
-}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs
deleted file mode 100644
index 4522bd2d..00000000
--- a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs
+++ /dev/null
@@ -1,22 +0,0 @@
-using System;
-using System.Collections.Generic;
-using StabilityMatrix.Core.Models;
-using StabilityMatrix.Core.Models.FileInterfaces;
-using StabilityMatrix.Core.Services;
-
-namespace StabilityMatrix.Avalonia.DesignData;
-
-public class MockTrackedDownloadService : ITrackedDownloadService
-{
- ///
- public IEnumerable Downloads => Array.Empty();
-
- ///
- public event EventHandler? DownloadAdded;
-
- ///
- public TrackedDownload NewDownload(Uri downloadUrl, FilePath downloadPath)
- {
- throw new NotImplementedException();
- }
-}
diff --git a/StabilityMatrix.Avalonia/DialogHelper.cs b/StabilityMatrix.Avalonia/DialogHelper.cs
index 15397e48..d99bc8dc 100644
--- a/StabilityMatrix.Avalonia/DialogHelper.cs
+++ b/StabilityMatrix.Avalonia/DialogHelper.cs
@@ -9,16 +9,15 @@ using System.Text.Json;
using Avalonia;
using Avalonia.Controls;
using Avalonia.Data;
+using Avalonia.Layout;
using Avalonia.LogicalTree;
using Avalonia.Media;
using Avalonia.Threading;
using AvaloniaEdit;
using AvaloniaEdit.TextMate;
-using CommunityToolkit.Mvvm.ComponentModel.__Internals;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using Markdown.Avalonia;
-using Markdown.Avalonia.SyntaxHigh.Extensions;
using NLog;
using Refit;
using StabilityMatrix.Avalonia.Controls;
@@ -26,7 +25,6 @@ using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
-using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Services;
using TextMateSharp.Grammars;
using Process = FuzzySharp.Process;
@@ -96,6 +94,18 @@ public static class DialogHelper
Watermark = field.Watermark,
DataContext = field,
};
+
+ if (!string.IsNullOrEmpty(field.InnerLeftText))
+ {
+ textBox.InnerLeftContent = new TextBlock()
+ {
+ Text = field.InnerLeftText,
+ Foreground = Brushes.Gray,
+ VerticalAlignment = VerticalAlignment.Center,
+ Margin = new Thickness(8, 0, -4, 0)
+ };
+ }
+
stackPanel.Children.Add(textBox);
// When IsValid property changes, update invalid count and primary button
@@ -427,7 +437,7 @@ public static class DialogHelper
AllowScrollBelowDocument = false
}
};
- TextEditorConfigs.ConfigForPrompt(textEditor);
+ TextEditorConfigs.Configure(textEditor, TextEditorPreset.Prompt);
textEditor.Document.Text = errorLineFormatted;
textEditor.TextArea.Caret.Offset = textEditor.Document.Lines[0].EndOffset;
@@ -518,15 +528,6 @@ public static class DialogHelper
XamlRoot = App.VisualRoot
};
}
-
- ///
- /// Creates a connection help dialog.
- ///
- public static BetterContentDialog CreateConnectionHelpDialog()
- {
- // TODO
- return new BetterContentDialog();
- }
}
// Text fields
@@ -541,6 +542,9 @@ public sealed class TextBoxField : INotifyPropertyChanged
// Watermark text
public string Watermark { get; init; } = string.Empty;
+ // Inner left value
+ public string? InnerLeftText { get; init; }
+
///
/// Validation action on text changes. Throw exception if invalid.
///
diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
index a6303f21..81691e93 100644
--- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
+++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
@@ -184,7 +184,7 @@ public static class ComfyNodeBuilderExtensions
var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple(
"Refiner_CheckpointLoader",
- modelCardViewModel.SelectedRefiner?.FileName
+ modelCardViewModel.SelectedRefiner?.RelativePath
?? throw new NullReferenceException("Model not selected")
)
);
@@ -282,20 +282,16 @@ public static class ComfyNodeBuilderExtensions
builder.Connections.ImageSize = builder.Connections.LatentSize;
}
- var saveImage = builder.Nodes.AddNamedNode(
+ var previewImage = builder.Nodes.AddNamedNode(
new NamedComfyNode("SaveImage")
{
- ClassType = "SaveImage",
- Inputs = new Dictionary
- {
- ["filename_prefix"] = "Inference/TextToImage",
- ["images"] = builder.Connections.Image
- }
+ ClassType = "PreviewImage",
+ Inputs = new Dictionary { ["images"] = builder.Connections.Image }
}
);
- builder.Connections.OutputNodes.Add(saveImage);
+ builder.Connections.OutputNodes.Add(previewImage);
- return saveImage.Name;
+ return previewImage.Name;
}
}
diff --git a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs
index 28c215d6..a090c6a2 100644
--- a/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs
+++ b/StabilityMatrix.Avalonia/Helpers/ImageProcessor.cs
@@ -13,50 +13,57 @@ public static class ImageProcessor
///
public static (int rows, int columns) GetGridDimensionsFromImageCount(int count)
{
- if (count <= 1) return (1, 1);
- if (count == 2) return (1, 2);
-
+ if (count <= 1)
+ return (1, 1);
+ if (count == 2)
+ return (1, 2);
+
// Prefer one extra row over one extra column,
// the row count will be the floor of the square root
// and the column count will be floor of count / rows
- var rows = (int) Math.Floor(Math.Sqrt(count));
- var columns = (int) Math.Floor((double) count / rows);
+ var rows = (int)Math.Floor(Math.Sqrt(count));
+ var columns = (int)Math.Floor((double)count / rows);
return (rows, columns);
}
-
- public static SKImage CreateImageGrid(
- IReadOnlyList images,
- int spacing = 0)
+
+ public static SKImage CreateImageGrid(IReadOnlyList images, int spacing = 0)
{
+ if (images.Count == 0)
+ throw new ArgumentException("Must have at least one image");
+
var (rows, columns) = GetGridDimensionsFromImageCount(images.Count);
var singleWidth = images[0].Width;
var singleHeight = images[0].Height;
-
+
// Make output image
using var output = new SKBitmap(
- singleWidth * columns + spacing * (columns - 1),
- singleHeight * rows + spacing * (rows - 1));
-
+ singleWidth * columns + spacing * (columns - 1),
+ singleHeight * rows + spacing * (rows - 1)
+ );
+
// Draw images
using var canvas = new SKCanvas(output);
-
- foreach (var (row, column) in
- Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns)))
+
+ foreach (
+ var (row, column) in Enumerable.Range(0, rows).Product(Enumerable.Range(0, columns))
+ )
{
// Stop if we have drawn all images
var index = row * columns + column;
- if (index >= images.Count) break;
-
+ if (index >= images.Count)
+ break;
+
// Get image
var image = images[index];
-
+
// Draw image
var destination = new SKRect(
singleWidth * column + spacing * column,
singleHeight * row + spacing * row,
singleWidth * column + spacing * column + image.Width,
- singleHeight * row + spacing * row + image.Height);
+ singleHeight * row + spacing * row + image.Height
+ );
canvas.DrawImage(image, destination);
}
diff --git a/StabilityMatrix.Avalonia/Helpers/ImageSearcher.cs b/StabilityMatrix.Avalonia/Helpers/ImageSearcher.cs
new file mode 100644
index 00000000..b5d280fe
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Helpers/ImageSearcher.cs
@@ -0,0 +1,95 @@
+using System;
+using FuzzySharp;
+using FuzzySharp.PreProcess;
+using StabilityMatrix.Core.Models.Database;
+
+namespace StabilityMatrix.Avalonia.Helpers;
+
+public class ImageSearcher
+{
+ public int MinimumFuzzScore { get; init; } = 80;
+
+ public ImageSearchOptions SearchOptions { get; init; } = ImageSearchOptions.All;
+
+ public Func GetPredicate(string? searchQuery)
+ {
+ if (string.IsNullOrEmpty(searchQuery))
+ {
+ return _ => true;
+ }
+
+ return file =>
+ {
+ if (file.FileName.Contains(searchQuery, StringComparison.OrdinalIgnoreCase))
+ {
+ return true;
+ }
+
+ if (
+ SearchOptions.HasFlag(ImageSearchOptions.FileName)
+ && Fuzz.WeightedRatio(searchQuery, file.FileName, PreprocessMode.Full)
+ > MinimumFuzzScore
+ )
+ {
+ return true;
+ }
+
+ // Generation params
+ if (file.GenerationParameters is { } parameters)
+ {
+ if (
+ SearchOptions.HasFlag(ImageSearchOptions.PositivePrompt)
+ && (
+ parameters.PositivePrompt?.Contains(
+ searchQuery,
+ StringComparison.OrdinalIgnoreCase
+ ) ?? false
+ )
+ || SearchOptions.HasFlag(ImageSearchOptions.NegativePrompt)
+ && (
+ parameters.NegativePrompt?.Contains(
+ searchQuery,
+ StringComparison.OrdinalIgnoreCase
+ ) ?? false
+ )
+ || SearchOptions.HasFlag(ImageSearchOptions.Seed)
+ && parameters.Seed
+ .ToString()
+ .StartsWith(searchQuery, StringComparison.OrdinalIgnoreCase)
+ || SearchOptions.HasFlag(ImageSearchOptions.Sampler)
+ && (
+ parameters.Sampler?.StartsWith(
+ searchQuery,
+ StringComparison.OrdinalIgnoreCase
+ ) ?? false
+ )
+ || SearchOptions.HasFlag(ImageSearchOptions.ModelName)
+ && (
+ parameters.ModelName?.StartsWith(
+ searchQuery,
+ StringComparison.OrdinalIgnoreCase
+ ) ?? false
+ )
+ )
+ {
+ return true;
+ }
+ }
+
+ return false;
+ };
+ }
+
+ [Flags]
+ public enum ImageSearchOptions
+ {
+ None = 0,
+ FileName = 1 << 0,
+ PositivePrompt = 1 << 1,
+ NegativePrompt = 1 << 2,
+ Seed = 1 << 3,
+ Sampler = 1 << 4,
+ ModelName = 1 << 5,
+ All = int.MaxValue
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs b/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs
index 086d0785..1d729cfe 100644
--- a/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs
+++ b/StabilityMatrix.Avalonia/Helpers/PngDataHelper.cs
@@ -3,7 +3,6 @@ using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
-using Avalonia;
using Force.Crc32;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Core.Models;
@@ -16,6 +15,17 @@ public static class PngDataHelper
private static readonly byte[] Text = { 0x74, 0x45, 0x58, 0x74 };
private static readonly byte[] Iend = { 0x49, 0x45, 0x4E, 0x44 };
+ public static byte[] AddMetadata(
+ Stream inputStream,
+ GenerationParameters generationParameters,
+ InferenceProjectDocument projectDocument
+ )
+ {
+ using var ms = new MemoryStream();
+ inputStream.CopyTo(ms);
+ return AddMetadata(ms.ToArray(), generationParameters, projectDocument);
+ }
+
public static byte[] AddMetadata(
byte[] inputImage,
GenerationParameters generationParameters,
diff --git a/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs b/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs
index 1afa66e9..414fc1a5 100644
--- a/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs
+++ b/StabilityMatrix.Avalonia/Helpers/TagCsvParser.cs
@@ -1,24 +1,21 @@
using System.Collections.Generic;
-using System.Data.Common;
-using System.Globalization;
using System.IO;
using System.Threading.Tasks;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using Sylvan.Data.Csv;
using Sylvan;
-using Sylvan.Data;
namespace StabilityMatrix.Avalonia.Helpers;
public class TagCsvParser
{
private readonly Stream stream;
-
+
public TagCsvParser(Stream stream)
{
this.stream = stream;
}
-
+
public async IAsyncEnumerable ParseAsync()
{
var pool = new StringPool();
@@ -27,10 +24,10 @@ public class TagCsvParser
StringFactory = pool.GetString,
HasHeaders = false,
};
-
+
using var textReader = new StreamReader(stream);
await using var dataReader = await CsvDataReader.CreateAsync(textReader, options);
-
+
while (await dataReader.ReadAsync())
{
var entry = new TagCsvEntry
@@ -42,7 +39,7 @@ public class TagCsvParser
};
yield return entry;
}
-
+
/*var dataBinderOptions = new DataBinderOptions
{
BindingMode = DataBindingMode.Any
@@ -54,17 +51,17 @@ public class TagCsvParser
public async Task> GetDictionaryAsync()
{
var dict = new Dictionary();
-
+
await foreach (var entry in ParseAsync())
{
if (entry.Name is null || entry.Type is null)
{
continue;
}
-
+
dict.Add(entry.Name, entry);
}
-
+
return dict;
}
}
diff --git a/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs b/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs
index 21139027..83a744eb 100644
--- a/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs
+++ b/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs
@@ -1,5 +1,5 @@
-using System.IO;
-using System.Reflection;
+using System;
+using System.IO;
using Avalonia.Media;
using AvaloniaEdit;
using AvaloniaEdit.TextMate;
@@ -18,13 +18,22 @@ public static class TextEditorConfigs
{
public static void Configure(TextEditor editor, TextEditorPreset preset)
{
- if (preset == TextEditorPreset.Prompt)
+ switch (preset)
{
- ConfigForPrompt(editor);
+ case TextEditorPreset.Prompt:
+ ConfigForPrompt(editor);
+ break;
+ case TextEditorPreset.Console:
+ ConfigForConsole(editor);
+ break;
+ case TextEditorPreset.None:
+ break;
+ default:
+ throw new ArgumentOutOfRangeException(nameof(preset), preset, null);
}
}
- public static void ConfigForPrompt(TextEditor editor)
+ private static void ConfigForPrompt(TextEditor editor)
{
const ThemeName themeName = ThemeName.DimmedMonokai;
var registryOptions = new RegistryOptions(themeName);
@@ -58,6 +67,25 @@ public static class TextEditorConfigs
installation.SetTheme(theme);
}
+ private static void ConfigForConsole(TextEditor editor)
+ {
+ var registryOptions = new RegistryOptions(ThemeName.DarkPlus);
+
+ // Config hyperlinks
+ editor.TextArea.Options.EnableHyperlinks = true;
+ editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false;
+ editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral;
+
+ var textMate = editor.InstallTextMate(registryOptions);
+ var scope = registryOptions.GetScopeByLanguageId("log");
+
+ if (scope is null)
+ throw new InvalidOperationException("Scope is null");
+
+ textMate.SetGrammar(scope);
+ textMate.SetTheme(registryOptions.LoadTheme(ThemeName.DarkPlus));
+ }
+
private static IRawTheme GetThemeFromStream(Stream stream)
{
using var reader = new StreamReader(stream);
diff --git a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs
index 07a38fb9..04cb8ffd 100644
--- a/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs
+++ b/StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs
@@ -126,7 +126,11 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper
}
}
- public async Task RunGit(string? workingDirectory = null, params string[] args)
+ public async Task RunGit(
+ string? workingDirectory = null,
+ Action? onProcessOutput = null,
+ params string[] args
+ )
{
var command =
args.Length == 0 ? "git" : "git " + string.Join(" ", args.Select(ProcessRunner.Quote));
@@ -229,6 +233,13 @@ public class UnixPrerequisiteHelper : IPrerequisiteHelper
throw new NotImplementedException();
}
+ [UnsupportedOSPlatform("Linux")]
+ [UnsupportedOSPlatform("macOS")]
+ public Task InstallTkinterIfNecessary(IProgress? progress = null)
+ {
+ throw new PlatformNotSupportedException();
+ }
+
[UnsupportedOSPlatform("Linux")]
[UnsupportedOSPlatform("macOS")]
public Task InstallVcRedistIfNecessary(IProgress? progress = null)
diff --git a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs
index e5ee642f..8df65b94 100644
--- a/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs
+++ b/StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs
@@ -17,70 +17,87 @@ namespace StabilityMatrix.Avalonia.Helpers;
public class WindowsPrerequisiteHelper : IPrerequisiteHelper
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
-
+
private readonly IGitHubClient gitHubClient;
private readonly IDownloadService downloadService;
private readonly ISettingsManager settingsManager;
-
+
private const string VcRedistDownloadUrl = "https://aka.ms/vs/16/release/vc_redist.x64.exe";
-
+ private const string TkinterDownloadUrl =
+ "https://cdn.lykos.ai/tkinter-cpython-embedded-3.10.11-win-x64.zip";
+
private string HomeDir => settingsManager.LibraryDir;
-
+
private string VcRedistDownloadPath => Path.Combine(HomeDir, "vcredist.x64.exe");
private string AssetsDir => Path.Combine(HomeDir, "Assets");
private string SevenZipPath => Path.Combine(AssetsDir, "7za.exe");
-
+
private string PythonDownloadPath => Path.Combine(AssetsDir, "python-3.10.11-embed-amd64.zip");
private string PythonDir => Path.Combine(AssetsDir, "Python310");
private string PythonDllPath => Path.Combine(PythonDir, "python310.dll");
private string PythonLibraryZipPath => Path.Combine(PythonDir, "python310.zip");
private string GetPipPath => Path.Combine(PythonDir, "get-pip.pyc");
+
// Temporary directory to extract venv to during python install
private string VenvTempDir => Path.Combine(PythonDir, "venv");
-
+
private string PortableGitInstallDir => Path.Combine(HomeDir, "PortableGit");
private string PortableGitDownloadPath => Path.Combine(HomeDir, "PortableGit.7z.exe");
private string GitExePath => Path.Combine(PortableGitInstallDir, "bin", "git.exe");
+ private string TkinterZipPath => Path.Combine(AssetsDir, "tkinter.zip");
+ private string TkinterExtractPath => PythonDir;
+ private string TkinterExistsPath => Path.Combine(PythonDir, "tkinter");
public string GitBinPath => Path.Combine(PortableGitInstallDir, "bin");
-
+
public bool IsPythonInstalled => File.Exists(PythonDllPath);
public WindowsPrerequisiteHelper(
IGitHubClient gitHubClient,
- IDownloadService downloadService,
- ISettingsManager settingsManager)
+ IDownloadService downloadService,
+ ISettingsManager settingsManager
+ )
{
this.gitHubClient = gitHubClient;
this.downloadService = downloadService;
this.settingsManager = settingsManager;
}
- public async Task RunGit(string? workingDirectory = null, params string[] args)
+ public async Task RunGit(
+ string? workingDirectory = null,
+ Action? onProcessOutput = null,
+ params string[] args
+ )
{
- var process = ProcessRunner.StartAnsiProcess(GitExePath, args,
+ var process = ProcessRunner.StartAnsiProcess(
+ GitExePath,
+ args,
workingDirectory: workingDirectory,
environmentVariables: new Dictionary
{
- {"PATH", Compat.GetEnvPathWithExtensions(GitBinPath)}
- });
-
+ { "PATH", Compat.GetEnvPathWithExtensions(GitBinPath) }
+ },
+ outputDataReceived: onProcessOutput
+ );
+
await ProcessRunner.WaitForExitConditionAsync(process);
}
public async Task GetGitOutput(string? workingDirectory = null, params string[] args)
{
var process = await ProcessRunner.GetProcessOutputAsync(
- GitExePath, string.Join(" ", args),
+ GitExePath,
+ string.Join(" ", args),
workingDirectory: workingDirectory,
environmentVariables: new Dictionary
{
- {"PATH", Compat.GetEnvPathWithExtensions(GitBinPath)}
- });
-
+ { "PATH", Compat.GetEnvPathWithExtensions(GitBinPath) }
+ }
+ );
+
return process;
}
-
+
public async Task InstallAllIfNecessary(IProgress? progress = null)
{
await InstallVcRedistIfNecessary(progress);
@@ -97,16 +114,20 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
(Assets.SevenZipExecutable, AssetsDir),
(Assets.SevenZipLicense, AssetsDir),
};
-
- progress?.Report(new ProgressReport(0, message: "Unpacking resources", isIndeterminate: true));
-
+
+ progress?.Report(
+ new ProgressReport(0, message: "Unpacking resources", isIndeterminate: true)
+ );
+
Directory.CreateDirectory(AssetsDir);
foreach (var (asset, extractDir) in assets)
{
await asset.ExtractToDir(extractDir);
}
-
- progress?.Report(new ProgressReport(1, message: "Unpacking resources", isIndeterminate: false));
+
+ progress?.Report(
+ new ProgressReport(1, message: "Unpacking resources", isIndeterminate: false)
+ );
}
public async Task InstallPythonIfNecessary(IProgress? progress = null)
@@ -120,7 +141,7 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
Logger.Info("Python not found at {PythonDllPath}, downloading...", PythonDllPath);
Directory.CreateDirectory(AssetsDir);
-
+
// Delete existing python zip if it exists
if (File.Exists(PythonLibraryZipPath))
{
@@ -130,44 +151,45 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
var remote = Assets.PythonDownloadUrl;
var url = remote.Url.ToString();
Logger.Info($"Downloading Python from {url} to {PythonLibraryZipPath}");
-
+
// Cleanup to remove zip if download fails
try
{
// Download python zip
await downloadService.DownloadToFileAsync(url, PythonDownloadPath, progress: progress);
-
+
// Verify python hash
var downloadHash = await FileHash.GetSha256Async(PythonDownloadPath, progress);
if (downloadHash != remote.HashSha256)
{
var fileExists = File.Exists(PythonDownloadPath);
var fileSize = new FileInfo(PythonDownloadPath).Length;
- var msg = $"Python download hash mismatch: {downloadHash} != {remote.HashSha256} " +
- $"(file exists: {fileExists}, size: {fileSize})";
+ var msg =
+ $"Python download hash mismatch: {downloadHash} != {remote.HashSha256} "
+ + $"(file exists: {fileExists}, size: {fileSize})";
throw new Exception(msg);
}
-
+
progress?.Report(new ProgressReport(progress: 1f, message: "Python download complete"));
-
+
progress?.Report(new ProgressReport(-1, "Installing Python...", isIndeterminate: true));
-
+
// We also need 7z if it's not already unpacked
if (!File.Exists(SevenZipPath))
{
await Assets.SevenZipExecutable.ExtractToDir(AssetsDir);
await Assets.SevenZipLicense.ExtractToDir(AssetsDir);
}
-
+
// Delete existing python dir
if (Directory.Exists(PythonDir))
{
Directory.Delete(PythonDir, true);
}
-
+
// Unzip python
await ArchiveHelper.Extract7Z(PythonDownloadPath, PythonDir);
-
+
try
{
// Extract embedded venv folder
@@ -185,7 +207,7 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
await resource.ExtractTo(path);
}
// Add venv to python's library zip
-
+
await ArchiveHelper.AddToArchive7Z(PythonLibraryZipPath, VenvTempDir);
}
finally
@@ -196,16 +218,19 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
Directory.Delete(VenvTempDir, true);
}
}
-
+
// Extract get-pip.pyc
await Assets.PyScriptGetPip.ExtractToDir(PythonDir);
-
+
// We need to uncomment the #import site line in python310._pth for pip to work
var pythonPthPath = Path.Combine(PythonDir, "python310._pth");
var pythonPthContent = await File.ReadAllTextAsync(pythonPthPath);
pythonPthContent = pythonPthContent.Replace("#import site", "import site");
await File.WriteAllTextAsync(pythonPthPath, pythonPthContent);
-
+
+ // Install TKinter
+ await InstallTkinterIfNecessary(progress);
+
progress?.Report(new ProgressReport(1f, "Python install complete"));
}
finally
@@ -218,6 +243,39 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
}
}
+ [SupportedOSPlatform("windows")]
+ public async Task InstallTkinterIfNecessary(IProgress? progress = null)
+ {
+ if (!Directory.Exists(TkinterExistsPath))
+ {
+ Logger.Info("Downloading Tkinter");
+ await downloadService.DownloadToFileAsync(
+ TkinterDownloadUrl,
+ TkinterZipPath,
+ progress: progress
+ );
+ progress?.Report(
+ new ProgressReport(
+ progress: 1f,
+ message: "Tkinter download complete",
+ type: ProgressType.Download
+ )
+ );
+
+ await ArchiveHelper.Extract(TkinterZipPath, TkinterExtractPath, progress);
+
+ File.Delete(TkinterZipPath);
+ }
+
+ progress?.Report(
+ new ProgressReport(
+ progress: 1f,
+ message: "Tkinter install complete",
+ type: ProgressType.Generic
+ )
+ );
+ }
+
public async Task InstallGitIfNecessary(IProgress? progress = null)
{
if (File.Exists(GitExePath))
@@ -225,7 +283,7 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
Logger.Debug("Git already installed at {GitExePath}", GitExePath);
return;
}
-
+
Logger.Info("Git not found at {GitExePath}, downloading...", GitExePath);
var portableGitUrl =
@@ -233,7 +291,11 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
if (!File.Exists(PortableGitDownloadPath))
{
- await downloadService.DownloadToFileAsync(portableGitUrl, PortableGitDownloadPath, progress: progress);
+ await downloadService.DownloadToFileAsync(
+ portableGitUrl,
+ PortableGitDownloadPath,
+ progress: progress
+ );
progress?.Report(new ProgressReport(progress: 1f, message: "Git download complete"));
}
@@ -245,7 +307,9 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
{
var registry = Registry.LocalMachine;
var key = registry.OpenSubKey(
- @"SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\X64", false);
+ @"SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\X64",
+ false
+ );
if (key != null)
{
var buildId = Convert.ToUInt32(key.GetValue("Bld"));
@@ -254,20 +318,44 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
return;
}
}
-
+
Logger.Info("Downloading VC Redist");
- await downloadService.DownloadToFileAsync(VcRedistDownloadUrl, VcRedistDownloadPath, progress: progress);
- progress?.Report(new ProgressReport(progress: 1f, message: "Visual C++ download complete",
- type: ProgressType.Download));
-
+ await downloadService.DownloadToFileAsync(
+ VcRedistDownloadUrl,
+ VcRedistDownloadPath,
+ progress: progress
+ );
+ progress?.Report(
+ new ProgressReport(
+ progress: 1f,
+ message: "Visual C++ download complete",
+ type: ProgressType.Download
+ )
+ );
+
Logger.Info("Installing VC Redist");
- progress?.Report(new ProgressReport(progress: 0.5f, isIndeterminate: true, type: ProgressType.Generic, message: "Installing prerequisites..."));
- var process = ProcessRunner.StartAnsiProcess(VcRedistDownloadPath, "/install /quiet /norestart");
+ progress?.Report(
+ new ProgressReport(
+ progress: 0.5f,
+ isIndeterminate: true,
+ type: ProgressType.Generic,
+ message: "Installing prerequisites..."
+ )
+ );
+ var process = ProcessRunner.StartAnsiProcess(
+ VcRedistDownloadPath,
+ "/install /quiet /norestart"
+ );
await process.WaitForExitAsync();
- progress?.Report(new ProgressReport(progress: 1f, message: "Visual C++ install complete",
- type: ProgressType.Generic));
-
+ progress?.Report(
+ new ProgressReport(
+ progress: 1f,
+ message: "Visual C++ install complete",
+ type: ProgressType.Generic
+ )
+ );
+
File.Delete(VcRedistDownloadPath);
}
@@ -286,5 +374,4 @@ public class WindowsPrerequisiteHelper : IPrerequisiteHelper
File.Delete(PortableGitDownloadPath);
}
-
}
diff --git a/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs b/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs
index 73fc62e0..1f6f11e8 100644
--- a/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs
+++ b/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs
@@ -113,6 +113,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Clear Selection.
+ ///
+ public static string Action_ClearSelection {
+ get {
+ return ResourceManager.GetString("Action_ClearSelection", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Close.
///
@@ -131,6 +140,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Consolidate.
+ ///
+ public static string Action_Consolidate {
+ get {
+ return ResourceManager.GetString("Action_Consolidate", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Continue.
///
@@ -140,6 +158,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Copy.
+ ///
+ public static string Action_Copy {
+ get {
+ return ResourceManager.GetString("Action_Copy", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Delete.
///
@@ -149,6 +176,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Downgrade.
+ ///
+ public static string Action_Downgrade {
+ get {
+ return ResourceManager.GetString("Action_Downgrade", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Edit.
///
@@ -248,6 +284,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Open in Image Viewer.
+ ///
+ public static string Action_OpenInViewer {
+ get {
+ return ResourceManager.GetString("Action_OpenInViewer", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Open on CivitAI.
///
@@ -284,6 +329,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Refresh.
+ ///
+ public static string Action_Refresh {
+ get {
+ return ResourceManager.GetString("Action_Refresh", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Relaunch.
///
@@ -383,6 +437,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Select All.
+ ///
+ public static string Action_SelectAll {
+ get {
+ return ResourceManager.GetString("Action_SelectAll", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Select Directory.
///
@@ -410,6 +473,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Send to Inference.
+ ///
+ public static string Action_SendToInference {
+ get {
+ return ResourceManager.GetString("Action_SendToInference", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Show in Explorer.
///
@@ -446,6 +518,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Upgrade.
+ ///
+ public static string Action_Upgrade {
+ get {
+ return ResourceManager.GetString("Action_Upgrade", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Yes.
///
@@ -509,6 +590,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Are you sure?.
+ ///
+ public static string Label_AreYouSure {
+ get {
+ return ResourceManager.GetString("Label_AreYouSure", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Automatically scroll to end of console output.
///
@@ -662,6 +752,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to This will move all generated images from the selected packages to the Consolidated directory of the shared outputs folder. This action cannot be undone..
+ ///
+ public static string Label_ConsolidateExplanation {
+ get {
+ return ResourceManager.GetString("Label_ConsolidateExplanation", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Current directory:.
///
@@ -888,7 +987,16 @@ namespace StabilityMatrix.Avalonia.Languages {
}
///
- /// Looks up a localized string similar to Import as Connected.
+ /// Looks up a localized string similar to Image to Image.
+ ///
+ public static string Label_ImageToImage {
+ get {
+ return ResourceManager.GetString("Label_ImageToImage", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Import with Metadata.
///
public static string Label_ImportAsConnected {
get {
@@ -932,6 +1040,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Inpainting.
+ ///
+ public static string Label_Inpainting {
+ get {
+ return ResourceManager.GetString("Label_Inpainting", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Input.
///
@@ -1148,6 +1265,24 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to {0} images selected.
+ ///
+ public static string Label_NumImagesSelected {
+ get {
+ return ResourceManager.GetString("Label_NumImagesSelected", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to 1 image selected.
+ ///
+ public static string Label_OneImageSelected {
+ get {
+ return ResourceManager.GetString("Label_OneImageSelected", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Only available on Windows.
///
@@ -1157,6 +1292,33 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Output Folder.
+ ///
+ public static string Label_OutputFolder {
+ get {
+ return ResourceManager.GetString("Label_OutputFolder", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Output Browser.
+ ///
+ public static string Label_OutputsPageTitle {
+ get {
+ return ResourceManager.GetString("Label_OutputsPageTitle", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Output Type.
+ ///
+ public static string Label_OutputType {
+ get {
+ return ResourceManager.GetString("Label_OutputType", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Package Environment.
///
@@ -1247,6 +1409,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Python Packages.
+ ///
+ public static string Label_PythonPackages {
+ get {
+ return ResourceManager.GetString("Label_PythonPackages", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Python Version Info.
///
@@ -1481,6 +1652,15 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Text to Image.
+ ///
+ public static string Label_TextToImage {
+ get {
+ return ResourceManager.GetString("Label_TextToImage", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Theme.
///
@@ -1526,6 +1706,24 @@ namespace StabilityMatrix.Avalonia.Languages {
}
}
+ ///
+ /// Looks up a localized string similar to Upscale.
+ ///
+ public static string Label_Upscale {
+ get {
+ return ResourceManager.GetString("Label_Upscale", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Output Sharing.
+ ///
+ public static string Label_UseSharedOutputFolder {
+ get {
+ return ResourceManager.GetString("Label_UseSharedOutputFolder", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to VAE.
///
diff --git a/StabilityMatrix.Avalonia/Languages/Resources.resx b/StabilityMatrix.Avalonia/Languages/Resources.resx
index 7575d1d8..f16aba39 100644
--- a/StabilityMatrix.Avalonia/Languages/Resources.resx
+++ b/StabilityMatrix.Avalonia/Languages/Resources.resx
@@ -379,7 +379,7 @@
Drop file here to import
- Import as Connected
+ Import with Metadata
Search for connected metadata on new local imports
@@ -678,7 +678,73 @@
Restore Default Layout
+
+ Output Sharing
+
Batch Index
-
\ No newline at end of file
+
+ Copy
+
+
+ Open in Image Viewer
+
+
+ {0} images selected
+
+
+ Output Folder
+
+
+ Output Type
+
+
+ Clear Selection
+
+
+ Select All
+
+
+ Send to Inference
+
+
+ Text to Image
+
+
+ Image to Image
+
+
+ Inpainting
+
+
+ Upscale
+
+
+ Output Browser
+
+
+ 1 image selected
+
+
+ Python Packages
+
+
+ Consolidate
+
+
+ Are you sure?
+
+
+ This will move all generated images from the selected packages to the Consolidated directory of the shared outputs folder. This action cannot be undone.
+
+
+ Refresh
+
+
+ Upgrade
+
+
+ Downgrade
+
+
diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs
new file mode 100644
index 00000000..5017ab38
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormat.cs
@@ -0,0 +1,72 @@
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+
+namespace StabilityMatrix.Avalonia.Models.Inference;
+
+public record FileNameFormat
+{
+ public string Template { get; }
+
+ public string Prefix { get; set; } = "";
+
+ public string Postfix { get; set; } = "";
+
+ public IReadOnlyList Parts { get; }
+
+ private FileNameFormat(string template, IReadOnlyList parts)
+ {
+ Template = template;
+ Parts = parts;
+ }
+
+ public FileNameFormat WithBatchPostFix(int current, int total)
+ {
+ return this with { Postfix = Postfix + $" ({current}-{total})" };
+ }
+
+ public FileNameFormat WithGridPrefix()
+ {
+ return this with { Prefix = Prefix + "Grid_" };
+ }
+
+ public string GetFileName()
+ {
+ return Prefix
+ + string.Join(
+ "",
+ Parts.Select(
+ part => part.Match(constant => constant, substitution => substitution.Invoke())
+ )
+ )
+ + Postfix;
+ }
+
+ public static FileNameFormat Parse(string template, FileNameFormatProvider provider)
+ {
+ var parts = provider.GetParts(template).ToImmutableArray();
+ return new FileNameFormat(template, parts);
+ }
+
+ public static bool TryParse(
+ string template,
+ FileNameFormatProvider provider,
+ [NotNullWhen(true)] out FileNameFormat? format
+ )
+ {
+ try
+ {
+ format = Parse(template, provider);
+ return true;
+ }
+ catch (ArgumentException)
+ {
+ format = null;
+ return false;
+ }
+ }
+
+ public const string DefaultTemplate = "{date}_{time}-{model_name}-{seed}";
+}
diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs
new file mode 100644
index 00000000..3b17284b
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs
@@ -0,0 +1,7 @@
+using System;
+using OneOf;
+
+namespace StabilityMatrix.Avalonia.Models.Inference;
+
+[GenerateOneOf]
+public partial class FileNameFormatPart : OneOfBase> { }
diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs
new file mode 100644
index 00000000..ff6905fd
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs
@@ -0,0 +1,197 @@
+using System;
+using System.Collections.Generic;
+using System.ComponentModel.DataAnnotations;
+using System.Diagnostics.Contracts;
+using System.IO;
+using System.Linq;
+using System.Text.RegularExpressions;
+using Avalonia.Data;
+using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Models;
+
+namespace StabilityMatrix.Avalonia.Models.Inference;
+
+public partial class FileNameFormatProvider
+{
+ public GenerationParameters? GenerationParameters { get; init; }
+
+ public InferenceProjectType? ProjectType { get; init; }
+
+ public string? ProjectName { get; init; }
+
+ private Dictionary>? _substitutions;
+
+ public Dictionary> Substitutions =>
+ _substitutions ??= new Dictionary>
+ {
+ { "seed", () => GenerationParameters?.Seed.ToString() },
+ { "prompt", () => GenerationParameters?.PositivePrompt },
+ { "negative_prompt", () => GenerationParameters?.NegativePrompt },
+ {
+ "model_name",
+ () => Path.GetFileNameWithoutExtension(GenerationParameters?.ModelName)
+ },
+ { "model_hash", () => GenerationParameters?.ModelHash },
+ { "width", () => GenerationParameters?.Width.ToString() },
+ { "height", () => GenerationParameters?.Height.ToString() },
+ { "project_type", () => ProjectType?.GetStringValue() },
+ { "project_name", () => ProjectName },
+ { "date", () => DateTime.Now.ToString("yyyy-MM-dd") },
+ { "time", () => DateTime.Now.ToString("HH-mm-ss") }
+ };
+
+ ///
+ /// Validate a format string
+ ///
+ /// Format string
+ /// Thrown if the format string contains an unknown variable
+ [Pure]
+ public ValidationResult Validate(string format)
+ {
+ var regex = BracketRegex();
+ var matches = regex.Matches(format);
+ var variables = matches.Select(m => m.Groups[1].Value);
+
+ foreach (var variableText in variables)
+ {
+ try
+ {
+ var (variable, _) = ExtractVariableAndSlice(variableText);
+
+ if (!Substitutions.ContainsKey(variable))
+ {
+ return new ValidationResult($"Unknown variable '{variable}'");
+ }
+ }
+ catch (Exception e)
+ {
+ return new ValidationResult($"Invalid variable '{variableText}': {e.Message}");
+ }
+ }
+
+ return ValidationResult.Success!;
+ }
+
+ public IEnumerable GetParts(string template)
+ {
+ var regex = BracketRegex();
+ var matches = regex.Matches(template);
+
+ var parts = new List();
+
+ // Loop through all parts of the string, including matches and non-matches
+ var currentIndex = 0;
+
+ foreach (var result in matches.Cast())
+ {
+ // If the match is not at the start of the string, add a constant part
+ if (result.Index != currentIndex)
+ {
+ var constant = template[currentIndex..result.Index];
+ parts.Add(constant);
+
+ currentIndex += constant.Length;
+ }
+
+ // Now we're at start of the current match, add the variable part
+ var (variable, slice) = ExtractVariableAndSlice(result.Groups[1].Value);
+ var substitution = Substitutions[variable];
+
+ // Slice string if necessary
+ if (slice is not null)
+ {
+ parts.Add(
+ (FileNameFormatPart)(
+ () =>
+ {
+ var value = substitution();
+ if (value is null)
+ return null;
+
+ if (slice.End is null)
+ {
+ value = value[(slice.Start ?? 0)..];
+ }
+ else
+ {
+ var length =
+ Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0);
+ value = value.Substring(slice.Start ?? 0, length);
+ }
+
+ return value;
+ }
+ )
+ );
+ }
+ else
+ {
+ parts.Add(substitution);
+ }
+
+ currentIndex += result.Length;
+ }
+
+ // Add remaining as constant
+ if (currentIndex != template.Length)
+ {
+ var constant = template[currentIndex..];
+ parts.Add(constant);
+ }
+
+ return parts;
+ }
+
+ ///
+ /// Return a sample provider for UI preview
+ ///
+ public static FileNameFormatProvider GetSample()
+ {
+ return new FileNameFormatProvider
+ {
+ GenerationParameters = GenerationParameters.GetSample(),
+ ProjectType = InferenceProjectType.TextToImage,
+ ProjectName = "Sample Project"
+ };
+ }
+
+ ///
+ /// Extract variable and index from a combined string
+ ///
+ private static (string Variable, Slice? Slice) ExtractVariableAndSlice(string combined)
+ {
+ if (IndexRegex().Matches(combined).FirstOrDefault() is not { Success: true } match)
+ {
+ return (combined, null);
+ }
+
+ // Variable is everything before the match
+ var variable = combined[..match.Groups[0].Index];
+
+ var start = match.Groups["start"].Value;
+ var end = match.Groups["end"].Value;
+ var step = match.Groups["step"].Value;
+
+ var slice = new Slice(
+ string.IsNullOrEmpty(start) ? null : int.Parse(start),
+ string.IsNullOrEmpty(end) ? null : int.Parse(end),
+ string.IsNullOrEmpty(step) ? null : int.Parse(step)
+ );
+
+ return (variable, slice);
+ }
+
+ ///
+ /// Regex for matching contents within a curly brace.
+ ///
+ [GeneratedRegex(@"\{([a-z_:\d\[\]]+)\}")]
+ private static partial Regex BracketRegex();
+
+ ///
+ /// Regex for matching a Python-like array index.
+ ///
+ [GeneratedRegex(@"\[(?:(?-?\d+)?)\:(?:(?-?\d+)?)?(?:\:(?-?\d+))?\]")]
+ private static partial Regex IndexRegex();
+
+ private record Slice(int? Start, int? End, int? Step);
+}
diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs
new file mode 100644
index 00000000..a453b3bc
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatVar.cs
@@ -0,0 +1,8 @@
+namespace StabilityMatrix.Avalonia.Models.Inference;
+
+public record FileNameFormatVar
+{
+ public required string Variable { get; init; }
+
+ public string? Example { get; init; }
+}
diff --git a/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs b/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs
index c4c76fa4..4c18bd81 100644
--- a/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs
+++ b/StabilityMatrix.Avalonia/Models/Inference/Prompt.cs
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
-using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
diff --git a/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs b/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs
index 5c782de4..e361207b 100644
--- a/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs
+++ b/StabilityMatrix.Avalonia/Models/Inference/StackExpanderModel.cs
@@ -1,5 +1,4 @@
using System.Text.Json.Serialization;
-using StabilityMatrix.Avalonia.ViewModels.Inference;
namespace StabilityMatrix.Avalonia.Models.Inference;
diff --git a/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs b/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs
index f736fdd8..689fe144 100644
--- a/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs
+++ b/StabilityMatrix.Avalonia/Models/Inference/ViewState.cs
@@ -1,5 +1,4 @@
-using System.Text.Json.Nodes;
-using System.Text.Json.Serialization;
+using System.Text.Json.Serialization;
namespace StabilityMatrix.Avalonia.Models.Inference;
diff --git a/StabilityMatrix.Avalonia/Models/PackageOutputCategory.cs b/StabilityMatrix.Avalonia/Models/PackageOutputCategory.cs
new file mode 100644
index 00000000..2318d0f8
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Models/PackageOutputCategory.cs
@@ -0,0 +1,7 @@
+namespace StabilityMatrix.Avalonia.Models;
+
+public class PackageOutputCategory
+{
+ public required string Name { get; set; }
+ public required string Path { get; set; }
+}
diff --git a/StabilityMatrix.Avalonia/Models/SharedState.cs b/StabilityMatrix.Avalonia/Models/SharedState.cs
index d2dd8fe1..17fd3288 100644
--- a/StabilityMatrix.Avalonia/Models/SharedState.cs
+++ b/StabilityMatrix.Avalonia/Models/SharedState.cs
@@ -1,14 +1,17 @@
using CommunityToolkit.Mvvm.ComponentModel;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Models;
///
/// Singleton DI service for observable shared UI state.
///
+[Singleton]
public partial class SharedState : ObservableObject
{
///
/// Whether debug mode enabled from settings page version tap.
///
- [ObservableProperty] private bool isDebugMode;
+ [ObservableProperty]
+ private bool isDebugMode;
}
diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs
index 1b4395a4..5b03207d 100644
--- a/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs
+++ b/StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
-using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text.RegularExpressions;
@@ -16,6 +15,7 @@ using NLog;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Services;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
@@ -26,6 +26,7 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
+[Singleton(typeof(ICompletionProvider))]
public partial class CompletionProvider : ICompletionProvider
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs
index c93df5fa..fea2490c 100644
--- a/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs
+++ b/StabilityMatrix.Avalonia/Models/TagCompletion/TextCompletionRequest.cs
@@ -1,6 +1,4 @@
-using System.Collections.Generic;
-using AvaloniaEdit.Document;
-using StabilityMatrix.Core.Models.Tokens;
+using StabilityMatrix.Core.Models.Tokens;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
diff --git a/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs b/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs
index 74b22d44..2181e5df 100644
--- a/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs
+++ b/StabilityMatrix.Avalonia/Models/TagCompletion/TokenizerProvider.cs
@@ -1,10 +1,12 @@
using System.Diagnostics.CodeAnalysis;
using StabilityMatrix.Avalonia.Extensions;
+using StabilityMatrix.Core.Attributes;
using TextMateSharp.Grammars;
using TextMateSharp.Registry;
namespace StabilityMatrix.Avalonia.Models.TagCompletion;
+[Singleton(typeof(ITokenizerProvider))]
public class TokenizerProvider : ITokenizerProvider
{
private readonly Registry registry = new(new RegistryOptions(ThemeName.DarkPlus));
diff --git a/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs b/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs
index 3caf0b86..50416643 100644
--- a/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs
+++ b/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs
@@ -3,5 +3,6 @@
public enum TextEditorPreset
{
None,
- Prompt
+ Prompt,
+ Console
}
diff --git a/StabilityMatrix.Avalonia/Program.cs b/StabilityMatrix.Avalonia/Program.cs
index 40c95d08..899f9fc7 100644
--- a/StabilityMatrix.Avalonia/Program.cs
+++ b/StabilityMatrix.Avalonia/Program.cs
@@ -7,6 +7,7 @@ using System.Reflection;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
+using AsyncAwaitBestPractices;
using AsyncImageLoader;
using Avalonia;
using Avalonia.Controls;
@@ -232,10 +233,16 @@ public class Program
if (e.ExceptionObject is not Exception ex)
return;
- Logger.Fatal(ex, "Unhandled {Type}: {Message}", ex.GetType().Name, ex.Message);
+ // Exception automatically logged by Sentry if enabled
if (SentrySdk.IsEnabled)
{
+ ex.SetSentryMechanism("AppDomain.UnhandledException", handled: false);
SentrySdk.CaptureException(ex);
+ SentrySdk.FlushAsync().SafeFireAndForget();
+ }
+ else
+ {
+ Logger.Fatal(ex, "Unhandled {Type}: {Message}", ex.GetType().Name, ex.Message);
}
if (
@@ -290,6 +297,10 @@ public class Program
[DoesNotReturn]
private static void ExitWithException(Exception exception)
{
+ if (SentrySdk.IsEnabled)
+ {
+ SentrySdk.Flush();
+ }
App.Shutdown(1);
Dispatcher.UIThread.InvokeShutdown();
Environment.Exit(Marshal.GetHRForException(exception));
diff --git a/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
index bc6afaff..f40990fa 100644
--- a/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
+++ b/StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
@@ -1,5 +1,4 @@
using System;
-using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
diff --git a/StabilityMatrix.Avalonia/Services/INotificationService.cs b/StabilityMatrix.Avalonia/Services/INotificationService.cs
index 77a4d9d2..d3cc7e77 100644
--- a/StabilityMatrix.Avalonia/Services/INotificationService.cs
+++ b/StabilityMatrix.Avalonia/Services/INotificationService.cs
@@ -2,6 +2,8 @@
using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls.Notifications;
+using Microsoft.Extensions.Logging;
+using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Services;
@@ -11,7 +13,8 @@ public interface INotificationService
public void Initialize(
Visual? visual,
NotificationPosition position = NotificationPosition.BottomRight,
- int maxItems = 3);
+ int maxItems = 3
+ );
public void Show(INotification notification);
@@ -26,7 +29,8 @@ public interface INotificationService
Task task,
string title = "Error",
string? message = null,
- NotificationType appearance = NotificationType.Error);
+ NotificationType appearance = NotificationType.Error
+ );
///
/// Attempt to run the given void task, showing a generic error notification if it fails.
@@ -40,16 +44,18 @@ public interface INotificationService
Task task,
string title = "Error",
string? message = null,
- NotificationType appearance = NotificationType.Error);
+ NotificationType appearance = NotificationType.Error
+ );
///
/// Show a notification with the given parameters.
///
void Show(
- string title,
+ string title,
string message,
NotificationType appearance = NotificationType.Information,
- TimeSpan? expiration = null);
+ TimeSpan? expiration = null
+ );
///
/// Show a notification that will not auto-dismiss.
@@ -60,5 +66,15 @@ public interface INotificationService
void ShowPersistent(
string title,
string message,
- NotificationType appearance = NotificationType.Information);
+ NotificationType appearance = NotificationType.Information
+ );
+
+ ///
+ /// Show a notification for a that will not auto-dismiss.
+ ///
+ void ShowPersistent(
+ AppException exception,
+ NotificationType appearance = NotificationType.Error,
+ LogLevel logLevel = LogLevel.Warning
+ );
}
diff --git a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
index 3e71a877..8254e467 100644
--- a/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
+++ b/StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
@@ -13,6 +13,7 @@ using SkiaSharp;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Core.Api;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
@@ -27,6 +28,7 @@ namespace StabilityMatrix.Avalonia.Services;
/// Manager for the current inference client
/// Has observable shared properties for shared info like model names
///
+[Singleton(typeof(IInferenceClientManager))]
public partial class InferenceClientManager : ObservableObject, IInferenceClientManager
{
private readonly ILogger logger;
@@ -345,6 +347,44 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
}
+ private async Task MigrateLinksIfNeeded(PackagePair packagePair)
+ {
+ if (packagePair.InstalledPackage.FullPath is not { } packagePath)
+ {
+ throw new ArgumentException("Package path is null", nameof(packagePair));
+ }
+
+ var inferenceDir = settingsManager.ImagesInferenceDirectory;
+ inferenceDir.Create();
+
+ // For locally installed packages only
+ // Delete ./output/Inference
+
+ var legacyInferenceLinkDir = new DirectoryPath(
+ packagePair.InstalledPackage.FullPath
+ ).JoinDir("output", "Inference");
+
+ if (legacyInferenceLinkDir.Exists)
+ {
+ logger.LogInformation(
+ "Deleting legacy inference link at {LegacyDir}",
+ legacyInferenceLinkDir
+ );
+
+ if (legacyInferenceLinkDir.IsSymbolicLink)
+ {
+ await legacyInferenceLinkDir.DeleteAsync(false);
+ }
+ else
+ {
+ logger.LogWarning(
+ "Legacy inference link at {LegacyDir} is not a symbolic link, skipping",
+ legacyInferenceLinkDir
+ );
+ }
+ }
+ }
+
///
public async Task ConnectAsync(
PackagePair packagePair,
@@ -367,11 +407,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
logger.LogError(ex, "Error setting up completion provider");
});
- // Setup image folder links
- await comfyPackage.SetupInferenceOutputFolderLinks(
- packagePair.InstalledPackage.FullPath
- ?? throw new InvalidOperationException("Package does not have a Path")
- );
+ await MigrateLinksIfNeeded(packagePair);
// Get user defined host and port
var host = packagePair.InstalledPackage.GetLaunchArgsHost();
diff --git a/StabilityMatrix.Avalonia/Services/NavigationService.cs b/StabilityMatrix.Avalonia/Services/NavigationService.cs
index 4cc47881..b9fb1952 100644
--- a/StabilityMatrix.Avalonia/Services/NavigationService.cs
+++ b/StabilityMatrix.Avalonia/Services/NavigationService.cs
@@ -1,6 +1,4 @@
using System;
-using System.Collections.Generic;
-using System.Diagnostics;
using System.Linq;
using FluentAvalonia.UI.Controls;
using FluentAvalonia.UI.Media.Animation;
@@ -8,10 +6,12 @@ using FluentAvalonia.UI.Navigation;
using StabilityMatrix.Avalonia.Animations;
using StabilityMatrix.Avalonia.ViewModels;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.Services;
+[Singleton(typeof(INavigationService))]
public class NavigationService : INavigationService
{
private Frame? _frame;
diff --git a/StabilityMatrix.Avalonia/Services/NotificationService.cs b/StabilityMatrix.Avalonia/Services/NotificationService.cs
index 1bb664b9..98611f8d 100644
--- a/StabilityMatrix.Avalonia/Services/NotificationService.cs
+++ b/StabilityMatrix.Avalonia/Services/NotificationService.cs
@@ -3,20 +3,32 @@ using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
+using Microsoft.Extensions.Logging;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Services;
+[Singleton(typeof(INotificationService))]
public class NotificationService : INotificationService
{
+ private readonly ILogger logger;
private WindowNotificationManager? notificationManager;
-
+
+ public NotificationService(ILogger logger)
+ {
+ this.logger = logger;
+ }
+
public void Initialize(
- Visual? visual,
+ Visual? visual,
NotificationPosition position = NotificationPosition.BottomRight,
- int maxItems = 4)
+ int maxItems = 4
+ )
{
- if (notificationManager is not null) return;
+ if (notificationManager is not null)
+ return;
notificationManager = new WindowNotificationManager(TopLevel.GetTopLevel(visual))
{
Position = position,
@@ -30,28 +42,44 @@ public class NotificationService : INotificationService
}
public void Show(
- string title,
+ string title,
string message,
NotificationType appearance = NotificationType.Information,
- TimeSpan? expiration = null)
+ TimeSpan? expiration = null
+ )
{
Show(new Notification(title, message, appearance, expiration));
}
public void ShowPersistent(
- string title,
+ string title,
string message,
- NotificationType appearance = NotificationType.Information)
+ NotificationType appearance = NotificationType.Information
+ )
{
Show(new Notification(title, message, appearance, TimeSpan.Zero));
}
-
+
+ ///
+ public void ShowPersistent(
+ AppException exception,
+ NotificationType appearance = NotificationType.Warning,
+ LogLevel logLevel = LogLevel.Warning
+ )
+ {
+ // Log exception
+ logger.Log(logLevel, exception, "{Message}", exception.Message);
+
+ Show(new Notification(exception.Message, exception.Details, appearance, TimeSpan.Zero));
+ }
+
///
public async Task> TryAsync(
Task task,
string title = "Error",
string? message = null,
- NotificationType appearance = NotificationType.Error)
+ NotificationType appearance = NotificationType.Error
+ )
{
try
{
@@ -63,13 +91,14 @@ public class NotificationService : INotificationService
return TaskResult.FromException(e);
}
}
-
+
///
public async Task> TryAsync(
Task task,
string title = "Error",
string? message = null,
- NotificationType appearance = NotificationType.Error)
+ NotificationType appearance = NotificationType.Error
+ )
{
try
{
diff --git a/StabilityMatrix.Avalonia/Services/ServiceManager.cs b/StabilityMatrix.Avalonia/Services/ServiceManager.cs
index 26d604f4..4b21c294 100644
--- a/StabilityMatrix.Avalonia/Services/ServiceManager.cs
+++ b/StabilityMatrix.Avalonia/Services/ServiceManager.cs
@@ -256,4 +256,19 @@ public class ServiceManager
return new BetterContentDialog { Content = view };
}
+
+ public void Register(Type type, Func providerFunc)
+ {
+ lock (providers)
+ {
+ if (instances.ContainsKey(type) || providers.ContainsKey(type))
+ {
+ throw new ArgumentException(
+ $"Service of type {type} is already registered for {typeof(T)}"
+ );
+ }
+
+ providers[type] = providerFunc;
+ }
+ }
}
diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
index 648330dd..4ceea5b9 100644
--- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
+++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
@@ -8,7 +8,7 @@
app.manifest
true
./Assets/Icon.ico
- 2.5.5-dev.1
+ 2.6.0-dev.1
$(Version)
true
true
@@ -23,46 +23,48 @@
-
-
-
+
+
+
-
+
-
+
-
+
-
-
+
+
-
+
-
+
-
-
-
+
+
+
+
-
+
-
-
+
+
+
-
+
diff --git a/StabilityMatrix.Avalonia/Styles/ThemeColors.cs b/StabilityMatrix.Avalonia/Styles/ThemeColors.cs
index 5df260c2..c59b9d53 100644
--- a/StabilityMatrix.Avalonia/Styles/ThemeColors.cs
+++ b/StabilityMatrix.Avalonia/Styles/ThemeColors.cs
@@ -1,5 +1,4 @@
-using Avalonia;
-using Avalonia.Media;
+using Avalonia.Media;
namespace StabilityMatrix.Avalonia.Styles;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs
index b905bba6..a4ed0309 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs
@@ -1,10 +1,14 @@
using System;
+using CommunityToolkit.Mvvm.ComponentModel;
using FluentAvalonia.UI.Controls;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
-public class ContentDialogProgressViewModelBase : ConsoleProgressViewModel
+public partial class ContentDialogProgressViewModelBase : ConsoleProgressViewModel
{
+ [ObservableProperty]
+ private bool hideCloseButton;
+
public event EventHandler? PrimaryButtonClick;
public event EventHandler? SecondaryButtonClick;
public event EventHandler? CloseButtonClick;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
index 3bd7e614..f1d0d11c 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs
@@ -3,11 +3,13 @@ using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.IO;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
+using Avalonia.Controls.Notifications;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.Input;
using NLog;
@@ -27,6 +29,8 @@ using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
@@ -41,6 +45,7 @@ public abstract partial class InferenceGenerationViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
+ private readonly ISettingsManager settingsManager;
private readonly INotificationService notificationService;
private readonly ServiceManager vmFactory;
@@ -60,11 +65,13 @@ public abstract partial class InferenceGenerationViewModelBase
protected InferenceGenerationViewModelBase(
ServiceManager vmFactory,
IInferenceClientManager inferenceClientManager,
- INotificationService notificationService
+ INotificationService notificationService,
+ ISettingsManager settingsManager
)
: base(notificationService)
{
this.notificationService = notificationService;
+ this.settingsManager = settingsManager;
this.vmFactory = vmFactory;
ClientManager = inferenceClientManager;
@@ -75,6 +82,101 @@ public abstract partial class InferenceGenerationViewModelBase
GenerateImageCommand.WithConditionalNotificationErrorHandler(notificationService);
}
+ ///
+ /// Write an image to the default output folder
+ ///
+ protected Task WriteOutputImageAsync(
+ Stream imageStream,
+ ImageGenerationEventArgs args,
+ int batchNum = 0,
+ int batchTotal = 0,
+ bool isGrid = false
+ )
+ {
+ var defaultOutputDir = settingsManager.ImagesInferenceDirectory;
+ defaultOutputDir.Create();
+
+ return WriteOutputImageAsync(
+ imageStream,
+ defaultOutputDir,
+ args,
+ batchNum,
+ batchTotal,
+ isGrid
+ );
+ }
+
+ ///
+ /// Write an image to an output folder
+ ///
+ protected async Task WriteOutputImageAsync(
+ Stream imageStream,
+ DirectoryPath outputDir,
+ ImageGenerationEventArgs args,
+ int batchNum = 0,
+ int batchTotal = 0,
+ bool isGrid = false
+ )
+ {
+ var formatTemplateStr = settingsManager.Settings.InferenceOutputImageFileNameFormat;
+
+ var formatProvider = new FileNameFormatProvider
+ {
+ GenerationParameters = args.Parameters,
+ ProjectType = args.Project?.ProjectType,
+ ProjectName = ProjectFile?.NameWithoutExtension
+ };
+
+ // Parse to format
+ if (
+ string.IsNullOrEmpty(formatTemplateStr)
+ || !FileNameFormat.TryParse(formatTemplateStr, formatProvider, out var format)
+ )
+ {
+ // Fallback to default
+ Logger.Warn(
+ "Failed to parse format template: {FormatTemplate}, using default",
+ formatTemplateStr
+ );
+
+ format = FileNameFormat.Parse(FileNameFormat.DefaultTemplate, formatProvider);
+ }
+
+ if (isGrid)
+ {
+ format = format.WithGridPrefix();
+ }
+
+ if (batchNum >= 1 && batchTotal > 1)
+ {
+ format = format.WithBatchPostFix(batchNum, batchTotal);
+ }
+
+ var fileName = format.GetFileName();
+ var file = outputDir.JoinFile($"{fileName}.png");
+
+ // Until the file is free, keep adding _{i} to the end
+ for (var i = 0; i < 100; i++)
+ {
+ if (!file.Exists)
+ break;
+
+ file = outputDir.JoinFile($"{fileName}_{i + 1}.png");
+ }
+
+ // If that fails, append an 7-char uuid
+ if (file.Exists)
+ {
+ var uuid = Guid.NewGuid().ToString("N")[..7];
+ file = outputDir.JoinFile($"{fileName}_{uuid}.png");
+ }
+
+ await using var fileStream = file.Info.OpenWrite();
+ await imageStream.CopyToAsync(fileStream);
+
+ return file;
+ }
+
///
/// Builds the image generation prompt
///
@@ -156,7 +258,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Wait for prompt to finish
await promptTask.Task.WaitAsync(cancellationToken);
- Logger.Trace($"Prompt task {promptTask.Id} finished");
+ Logger.Debug($"Prompt task {promptTask.Id} finished");
// Get output images
var imageOutputs = await client.GetImagesForExecutedPromptAsync(
@@ -164,6 +266,20 @@ public abstract partial class InferenceGenerationViewModelBase
cancellationToken
);
+ if (
+ !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images)
+ || images is not { Count: > 0 }
+ )
+ {
+ // No images match
+ notificationService.Show(
+ "No output",
+ "Did not receive any output images",
+ NotificationType.Warning
+ );
+ return;
+ }
+
// Disable cancellation
await promptInterrupt.DisposeAsync();
@@ -172,15 +288,6 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGalleryCardViewModel.ImageSources.Clear();
}
- if (
- !imageOutputs.TryGetValue(args.OutputNodeNames[0], out var images) || images is null
- )
- {
- // No images match
- notificationService.Show("No output", "Did not receive any output images");
- return;
- }
-
await ProcessOutputImages(images, args);
}
finally
@@ -207,19 +314,22 @@ public abstract partial class InferenceGenerationViewModelBase
ImageGenerationEventArgs args
)
{
+ var client = args.Client;
+
// Write metadata to images
+ var outputImagesBytes = new List();
var outputImages = new List();
- foreach (
- var (i, filePath) in images
- .Select(image => image.ToFilePath(args.Client.OutputImagesDir!))
- .Enumerate()
- )
+
+ foreach (var (i, comfyImage) in images.Enumerate())
{
- if (!filePath.Exists)
- {
- Logger.Warn($"Image file {filePath} does not exist");
- continue;
- }
+ Logger.Debug("Downloading image: {FileName}", comfyImage.FileName);
+ var imageStream = await client.GetImageStreamAsync(comfyImage);
+
+ using var ms = new MemoryStream();
+ await imageStream.CopyToAsync(ms);
+
+ var imageArray = ms.ToArray();
+ outputImagesBytes.Add(imageArray);
var parameters = args.Parameters!;
var project = args.Project!;
@@ -248,17 +358,15 @@ public abstract partial class InferenceGenerationViewModelBase
);
}
- var bytesWithMetadata = PngDataHelper.AddMetadata(
- await filePath.ReadAllBytesAsync(),
- parameters,
- project
- );
+ var bytesWithMetadata = PngDataHelper.AddMetadata(imageArray, parameters, project);
- await using (var outputStream = filePath.Info.OpenWrite())
- {
- await outputStream.WriteAsync(bytesWithMetadata);
- await outputStream.FlushAsync();
- }
+ // Write using generated name
+ var filePath = await WriteOutputImageAsync(
+ new MemoryStream(bytesWithMetadata),
+ args,
+ i + 1,
+ images.Count
+ );
outputImages.Add(new ImageSource(filePath));
@@ -268,17 +376,7 @@ public abstract partial class InferenceGenerationViewModelBase
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
- var outputDir = outputImages[0].LocalFile!.Directory;
-
- var loadedImages = outputImages
- .Select(i => i.LocalFile)
- .Where(f => f is { Exists: true })
- .Select(f =>
- {
- using var stream = f!.Info.OpenRead();
- return SKImage.FromEncodedData(stream);
- })
- .ToImmutableArray();
+ var loadedImages = outputImagesBytes.Select(SKImage.FromEncodedData).ToImmutableArray();
var project = args.Project!;
@@ -297,13 +395,11 @@ public abstract partial class InferenceGenerationViewModelBase
);
// Save to disk
- var lastName = outputImages.Last().LocalFile?.Info.Name;
- var gridPath = outputDir!.JoinFile($"grid-{lastName}");
-
- await using (var fileStream = gridPath.Info.OpenWrite())
- {
- await fileStream.WriteAsync(gridBytesWithMetadata);
- }
+ var gridPath = await WriteOutputImageAsync(
+ new MemoryStream(gridBytesWithMetadata),
+ args,
+ isGrid: true
+ );
// Insert to start of images
var gridImage = new ImageSource(gridPath);
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs
index b31b5501..bc1d60ff 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs
@@ -310,7 +310,7 @@ public abstract partial class InferenceTabViewModelBase
if (this is IImageGalleryComponent imageGalleryComponent)
{
imageGalleryComponent.LoadImagesToGallery(
- new ImageSource(imageFile.GlobalFullPath)
+ new ImageSource(imageFile.AbsolutePath)
);
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs
index e4e458d8..0d06e87f 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs
@@ -1,5 +1,4 @@
using System;
-using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
namespace StabilityMatrix.Avalonia.ViewModels.Base;
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs
index a4785ae6..43288d83 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs
@@ -17,6 +17,7 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.Views.Dialogs;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
@@ -27,6 +28,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
+[ManagedService]
+[Transient]
public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -254,16 +257,17 @@ public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel
private static string PruneDescription(CivitModel model)
{
- var prunedDescription = model.Description
- .Replace("
", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("
", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("
", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
- .Replace("", $"{Environment.NewLine}{Environment.NewLine}");
+ var prunedDescription =
+ model.Description
+ ?.Replace("
", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("
", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}")
+ .Replace("", $"{Environment.NewLine}{Environment.NewLine}") ?? string.Empty;
prunedDescription = HtmlRegex().Replace(prunedDescription, string.Empty);
return prunedDescription;
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs
index 2ecfc90e..e7659258 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs
@@ -5,16 +5,13 @@ using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Net.Http;
-using System.Reactive;
using System.Reactive.Linq;
-using System.Reactive.Threading.Tasks;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Collections;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
-using AvaloniaEdit.Utils;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
@@ -42,6 +39,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(CheckpointBrowserPage))]
+[Singleton]
public partial class CheckpointBrowserViewModel : PageViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs
index f1260f93..a978d874 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs
@@ -9,7 +9,9 @@ using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
+using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
@@ -19,6 +21,8 @@ using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
+[ManagedService]
+[Transient]
public partial class CheckpointFile : ViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -257,6 +261,11 @@ public partial class CheckpointFile : ViewModelBase
.Where(File.Exists)
.FirstOrDefault();
+ if (string.IsNullOrWhiteSpace(checkpointFile.PreviewImagePath))
+ {
+ checkpointFile.PreviewImagePath = Assets.NoImage.ToString();
+ }
+
yield return checkpointFile;
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
index 398aaf64..ecff1c02 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
@@ -17,6 +17,7 @@ using DynamicData.Binding;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
@@ -27,6 +28,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
+[ManagedService]
+[Transient]
public partial class CheckpointFolder : ViewModelBase
{
private readonly ISettingsManager settingsManager;
@@ -99,7 +102,7 @@ public partial class CheckpointFolder : ViewModelBase
public IObservableCollection CheckpointFiles { get; } =
new ObservableCollectionExtended();
- public IObservableCollection DisplayedCheckpointFiles { get; } =
+ public IObservableCollection DisplayedCheckpointFiles { get; set; } =
new ObservableCollectionExtended();
public CheckpointFolder(
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
index c77b92d2..86b7fdd3 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
@@ -24,6 +24,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(CheckpointsPage))]
+[Singleton]
public partial class CheckpointsPageViewModel : PageViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs
index aaf2aaa0..47625eaf 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadResourceViewModel.cs
@@ -16,6 +16,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(DownloadResourceDialog))]
+[ManagedService]
+[Transient]
public partial class DownloadResourceViewModel : ContentDialogViewModelBase
{
private readonly IDownloadService downloadService;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs
index 55fa73b5..cab1a11a 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs
@@ -12,6 +12,8 @@ using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(EnvVarsViewModel))]
+[ManagedService]
+[Transient]
public partial class EnvVarsViewModel : ContentDialogViewModelBase
{
[ObservableProperty]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs
index 5c0b9427..42e8a6d6 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ExceptionViewModel.cs
@@ -6,11 +6,13 @@ using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(ExceptionDialog))]
+[ManagedService]
+[Transient]
public partial class ExceptionViewModel : ViewModelBase
{
public Exception? Exception { get; set; }
-
+
public string? Message => Exception?.Message;
-
+
public string? ExceptionType => Exception?.GetType().Name ?? "";
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs
index 821d8f5b..c07824e7 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageViewerViewModel.cs
@@ -21,6 +21,8 @@ using Size = StabilityMatrix.Core.Helper.Size;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(ImageViewerDialog))]
+[ManagedService]
+[Transient]
public partial class ImageViewerViewModel : ContentDialogViewModelBase
{
[ObservableProperty]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs
index ac0d1dc1..351f9a64 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InferenceConnectionHelpViewModel.cs
@@ -6,7 +6,6 @@ using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
-using StabilityMatrix.Avalonia.Animations;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Services;
@@ -23,6 +22,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(InferenceConnectionHelpDialog))]
+[ManagedService]
+[Transient]
public partial class InferenceConnectionHelpViewModel : ContentDialogViewModelBase
{
private readonly ISettingsManager settingsManager;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
index 33f92b9d..d156def6 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
@@ -19,6 +19,7 @@ using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models;
@@ -31,11 +32,14 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
+[ManagedService]
+[Transient]
public partial class InstallerViewModel : ContentDialogViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ISettingsManager settingsManager;
+ private readonly IPackageFactory packageFactory;
private readonly IPyRunner pyRunner;
private readonly IDownloadService downloadService;
private readonly INotificationService notificationService;
@@ -47,15 +51,15 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
[ObservableProperty]
private PackageVersion? selectedVersion;
- [ObservableProperty]
- private IReadOnlyList? availablePackages;
-
[ObservableProperty]
private ObservableCollection? availableCommits;
[ObservableProperty]
private ObservableCollection? availableVersions;
+ [ObservableProperty]
+ private ObservableCollection availablePackages;
+
[ObservableProperty]
private GitCommit? selectedCommit;
@@ -69,6 +73,9 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
[NotifyPropertyChangedFor(nameof(CanInstall))]
private bool showDuplicateWarning;
+ [ObservableProperty]
+ private bool showIncompatiblePackages;
+
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(CanInstall))]
private string? installName;
@@ -115,7 +122,6 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
public bool CanInstall =>
!string.IsNullOrWhiteSpace(InstallName) && !ShowDuplicateWarning && !IsLoading;
- public ProgressViewModel InstallProgress { get; } = new();
public IEnumerable Steps { get; set; }
public InstallerViewModel(
@@ -128,22 +134,25 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
)
{
this.settingsManager = settingsManager;
+ this.packageFactory = packageFactory;
this.pyRunner = pyRunner;
this.downloadService = downloadService;
this.notificationService = notificationService;
this.prerequisiteHelper = prerequisiteHelper;
- // AvailablePackages and SelectedPackage
+ var filtered = packageFactory.GetAllAvailablePackages().Where(p => p.IsCompatible).ToList();
+
AvailablePackages = new ObservableCollection(
- packageFactory.GetAllAvailablePackages()
+ filtered.Any() ? filtered : packageFactory.GetAllAvailablePackages()
);
- SelectedPackage = AvailablePackages[0];
+ ShowIncompatiblePackages = !filtered.Any();
}
public override void OnLoaded()
{
if (AvailablePackages == null)
return;
+
IsReleaseMode = !SelectedPackage.ShouldIgnoreReleases;
}
@@ -238,6 +247,8 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
downloadOptions.VersionTag =
SelectedVersion?.TagName
?? throw new NullReferenceException("Selected version is null");
+ downloadOptions.IsLatest =
+ AvailableVersions?.First().TagName == downloadOptions.VersionTag;
installedVersion.InstalledReleaseVersion = downloadOptions.VersionTag;
}
@@ -245,6 +256,11 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
{
downloadOptions.CommitHash =
SelectedCommit?.Sha ?? throw new NullReferenceException("Selected commit is null");
+ downloadOptions.BranchName =
+ SelectedVersion?.TagName
+ ?? throw new NullReferenceException("Selected version is null");
+ downloadOptions.IsLatest = AvailableCommits?.First().Sha == SelectedCommit.Sha;
+
installedVersion.InstalledBranch =
SelectedVersion?.TagName
?? throw new NullReferenceException("Selected version is null");
@@ -259,6 +275,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
var installStep = new InstallPackageStep(
SelectedPackage,
SelectedTorchVersion,
+ downloadOptions,
installLocation
);
var setupModelFoldersStep = new SetupModelFoldersStep(
@@ -301,12 +318,29 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
OnCloseButtonClick();
}
+ partial void OnShowIncompatiblePackagesChanged(bool value)
+ {
+ var filtered = packageFactory
+ .GetAllAvailablePackages()
+ .Where(p => ShowIncompatiblePackages || p.IsCompatible)
+ .ToList();
+
+ AvailablePackages = new ObservableCollection(
+ filtered.Any() ? filtered : packageFactory.GetAllAvailablePackages()
+ );
+ SelectedPackage = AvailablePackages[0];
+ }
+
private void UpdateSelectedVersionToLatestMain()
{
if (AvailableVersions is null)
{
SelectedVersion = null;
}
+ else if (SelectedPackage is FooocusMre)
+ {
+ SelectedVersion = AvailableVersions.FirstOrDefault(x => x.TagName == "moonride-main");
+ }
else
{
// First try to find master
@@ -358,41 +392,33 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
// When changing branch / release modes, refresh
// ReSharper disable once UnusedParameterInPartialMethod
- partial void OnSelectedVersionTypeChanged(PackageVersionType value) =>
- OnSelectedPackageChanged(SelectedPackage);
-
- partial void OnSelectedPackageChanged(BasePackage value)
+ partial void OnSelectedVersionTypeChanged(PackageVersionType value)
{
- IsLoading = true;
- ReleaseNotes = string.Empty;
- AvailableVersions?.Clear();
- AvailableCommits?.Clear();
-
- AvailableVersionTypes = SelectedPackage.ShouldIgnoreReleases
- ? PackageVersionType.Commit
- : PackageVersionType.GithubRelease | PackageVersionType.Commit;
- SelectedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod;
- SelectedTorchVersion = SelectedPackage.GetRecommendedTorchVersion();
- if (Design.IsDesignMode)
+ if (SelectedPackage is null || Design.IsDesignMode)
return;
Dispatcher.UIThread
.InvokeAsync(async () =>
{
Logger.Debug($"Release mode: {IsReleaseMode}");
- var versionOptions = await value.GetAllVersionOptions();
+ var versionOptions = await SelectedPackage.GetAllVersionOptions();
AvailableVersions = IsReleaseMode
? new ObservableCollection(versionOptions.AvailableVersions)
: new ObservableCollection(versionOptions.AvailableBranches);
- SelectedVersion = AvailableVersions.First(x => !x.IsPrerelease);
+ SelectedVersion = AvailableVersions?.FirstOrDefault(x => !x.IsPrerelease);
+ if (SelectedVersion is null)
+ return;
+
ReleaseNotes = SelectedVersion.ReleaseNotesMarkdown;
Logger.Debug($"Loaded release notes for {ReleaseNotes}");
if (!IsReleaseMode)
{
- var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList();
+ var commits = (
+ await SelectedPackage.GetAllCommits(SelectedVersion.TagName)
+ )?.ToList();
if (commits is null || commits.Count == 0)
return;
@@ -408,6 +434,29 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
.SafeFireAndForget();
}
+ partial void OnSelectedPackageChanged(BasePackage? value)
+ {
+ IsLoading = true;
+ ReleaseNotes = string.Empty;
+ AvailableVersions?.Clear();
+ AvailableCommits?.Clear();
+
+ if (value == null)
+ return;
+
+ AvailableVersionTypes = SelectedPackage.ShouldIgnoreReleases
+ ? PackageVersionType.Commit
+ : PackageVersionType.GithubRelease | PackageVersionType.Commit;
+ IsReleaseMode = !SelectedPackage.ShouldIgnoreReleases;
+ SelectedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod;
+ SelectedTorchVersion = SelectedPackage.GetRecommendedTorchVersion();
+ SelectedVersionType = SelectedPackage.ShouldIgnoreReleases
+ ? PackageVersionType.Commit
+ : PackageVersionType.GithubRelease;
+
+ OnSelectedVersionTypeChanged(SelectedVersionType);
+ }
+
partial void OnInstallNameChanged(string? value)
{
ShowDuplicateWarning = settingsManager.Settings.InstalledPackages.Any(
@@ -418,7 +467,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
partial void OnSelectedVersionChanged(PackageVersion? value)
{
ReleaseNotes = value?.ReleaseNotesMarkdown ?? string.Empty;
- if (value == null)
+ if (value == null || Design.IsDesignMode)
return;
SelectedCommit = null;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs
index a4d2e4a6..303d4462 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs
@@ -17,6 +17,8 @@ using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(LaunchOptionsDialog))]
+[ManagedService]
+[Transient]
public partial class LaunchOptionsViewModel : ContentDialogViewModelBase
{
private readonly ILogger logger;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
index b4d4a823..155e2297 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
@@ -7,25 +8,29 @@ using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Avalonia.Languages;
+using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.PackageModification;
using StabilityMatrix.Core.Models.Packages;
-using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
-public partial class OneClickInstallViewModel : ViewModelBase
+[ManagedService]
+[Transient]
+public partial class OneClickInstallViewModel : ContentDialogViewModelBase
{
private readonly ISettingsManager settingsManager;
private readonly IPackageFactory packageFactory;
private readonly IPrerequisiteHelper prerequisiteHelper;
private readonly ILogger logger;
private readonly IPyRunner pyRunner;
- private readonly ISharedFolders sharedFolders;
+ private readonly INavigationService navigationService;
private const string DefaultPackageName = "stable-diffusion-webui";
[ObservableProperty]
@@ -43,6 +48,9 @@ public partial class OneClickInstallViewModel : ViewModelBase
[ObservableProperty]
private bool isIndeterminate;
+ [ObservableProperty]
+ private bool showIncompatiblePackages;
+
[ObservableProperty]
private ObservableCollection allPackages;
@@ -53,6 +61,8 @@ public partial class OneClickInstallViewModel : ViewModelBase
[NotifyPropertyChangedFor(nameof(IsProgressBarVisible))]
private int oneClickInstallProgress;
+ private bool isInferenceInstall;
+
public bool IsProgressBarVisible => OneClickInstallProgress > 0 || IsIndeterminate;
public OneClickInstallViewModel(
@@ -61,7 +71,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
IPrerequisiteHelper prerequisiteHelper,
ILogger logger,
IPyRunner pyRunner,
- ISharedFolders sharedFolders
+ INavigationService navigationService
)
{
this.settingsManager = settingsManager;
@@ -69,13 +79,21 @@ public partial class OneClickInstallViewModel : ViewModelBase
this.prerequisiteHelper = prerequisiteHelper;
this.logger = logger;
this.pyRunner = pyRunner;
- this.sharedFolders = sharedFolders;
+ this.navigationService = navigationService;
HeaderText = Resources.Text_WelcomeToStabilityMatrix;
SubHeaderText = Resources.Text_OneClickInstaller_SubHeader;
ShowInstallButton = true;
+
+ var filteredPackages = this.packageFactory
+ .GetAllAvailablePackages()
+ .Where(p => p is { OfferInOneClickInstaller: true, IsCompatible: true })
+ .ToList();
+
AllPackages = new ObservableCollection(
- this.packageFactory.GetAllAvailablePackages().Where(p => p.OfferInOneClickInstaller)
+ filteredPackages.Any()
+ ? filteredPackages
+ : this.packageFactory.GetAllAvailablePackages()
);
SelectedPackage = AllPackages[0];
}
@@ -95,39 +113,34 @@ public partial class OneClickInstallViewModel : ViewModelBase
return Task.CompletedTask;
}
- private async Task DoInstall()
+ [RelayCommand]
+ private async Task InstallComfyForInference()
{
- HeaderText = $"{Resources.Label_Installing} {SelectedPackage.DisplayName}";
-
- var progressHandler = new Progress(progress =>
- {
- SubHeaderText = $"{progress.Title} {progress.Percentage:N0}%";
-
- IsIndeterminate = progress.IsIndeterminate;
- OneClickInstallProgress = Convert.ToInt32(progress.Percentage);
- });
-
- await prerequisiteHelper.InstallAllIfNecessary(progressHandler);
-
- SubHeaderText = Resources.Progress_InstallingPrerequisites;
- IsIndeterminate = true;
- if (!PyRunner.PipInstalled)
+ var comfyPackage = AllPackages.FirstOrDefault(x => x is ComfyUI);
+ if (comfyPackage != null)
{
- await pyRunner.SetupPip();
+ SelectedPackage = comfyPackage;
+ isInferenceInstall = true;
+ await InstallCommand.ExecuteAsync(null);
}
+ }
- if (!PyRunner.VenvInstalled)
+ private async Task DoInstall()
+ {
+ var steps = new List
{
- await pyRunner.InstallPackage("virtualenv");
- }
- IsIndeterminate = false;
-
- var libraryDir = settingsManager.LibraryDir;
+ new SetPackageInstallingStep(settingsManager, SelectedPackage.Name),
+ new SetupPrerequisitesStep(prerequisiteHelper, pyRunner)
+ };
// get latest version & download & install
- var installLocation = Path.Combine(libraryDir, "Packages", SelectedPackage.Name);
+ var installLocation = Path.Combine(
+ settingsManager.LibraryDir,
+ "Packages",
+ SelectedPackage.Name
+ );
- var downloadVersion = new DownloadPackageVersionOptions();
+ var downloadVersion = new DownloadPackageVersionOptions { IsLatest = true };
var installedVersion = new InstalledPackageVersion();
var versionOptions = await SelectedPackage.GetAllVersionOptions();
@@ -139,16 +152,39 @@ public partial class OneClickInstallViewModel : ViewModelBase
else
{
downloadVersion.BranchName = await SelectedPackage.GetLatestVersion();
+ downloadVersion.CommitHash =
+ (await SelectedPackage.GetAllCommits(downloadVersion.BranchName))
+ ?.FirstOrDefault()
+ ?.Sha ?? string.Empty;
+
installedVersion.InstalledBranch = downloadVersion.BranchName;
+ installedVersion.InstalledCommitSha = downloadVersion.CommitHash;
}
var torchVersion = SelectedPackage.GetRecommendedTorchVersion();
+ var recommendedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod;
- await DownloadPackage(installLocation, downloadVersion);
- await InstallPackage(installLocation, torchVersion);
+ var downloadStep = new DownloadPackageVersionStep(
+ SelectedPackage,
+ installLocation,
+ downloadVersion
+ );
+ steps.Add(downloadStep);
- var recommendedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod;
- await SelectedPackage.SetupModelFolders(installLocation, recommendedSharedFolderMethod);
+ var installStep = new InstallPackageStep(
+ SelectedPackage,
+ torchVersion,
+ downloadVersion,
+ installLocation
+ );
+ steps.Add(installStep);
+
+ var setupModelFoldersStep = new SetupModelFoldersStep(
+ SelectedPackage,
+ recommendedSharedFolderMethod,
+ installLocation
+ );
+ steps.Add(setupModelFoldersStep);
var installedPackage = new InstalledPackage
{
@@ -162,59 +198,47 @@ public partial class OneClickInstallViewModel : ViewModelBase
PreferredTorchVersion = torchVersion,
PreferredSharedFolderMethod = recommendedSharedFolderMethod
};
- await using var st = settingsManager.BeginTransaction();
- st.Settings.InstalledPackages.Add(installedPackage);
- st.Settings.ActiveInstalledPackageId = installedPackage.Id;
- EventManager.Instance.OnInstalledPackagesChanged();
- HeaderText = Resources.Progress_InstallationComplete;
- SubSubHeaderText = string.Empty;
- OneClickInstallProgress = 100;
+ var addInstalledPackageStep = new AddInstalledPackageStep(
+ settingsManager,
+ installedPackage
+ );
+ steps.Add(addInstalledPackageStep);
- for (var i = 0; i < 3; i++)
+ var runner = new PackageModificationRunner
+ {
+ ShowDialogOnStart = true,
+ HideCloseButton = true,
+ };
+ EventManager.Instance.OnPackageInstallProgressAdded(runner);
+ await runner.ExecuteSteps(steps);
+
+ EventManager.Instance.OnInstalledPackagesChanged();
+ HeaderText = $"{SelectedPackage.DisplayName} installed successfully";
+ for (var i = 3; i > 0; i--)
{
- SubHeaderText = $"{Resources.Text_ProceedingToLaunchPage} ({i + 1}s)";
+ SubHeaderText = $"{Resources.Text_ProceedingToLaunchPage} ({i}s)";
await Task.Delay(1000);
}
// should close dialog
EventManager.Instance.OnOneClickInstallFinished(false);
- }
-
- private async Task DownloadPackage(
- string installLocation,
- DownloadPackageVersionOptions versionOptions
- )
- {
- SubHeaderText = Resources.Progress_DownloadingPackage;
-
- var progress = new Progress(progress =>
+ if (isInferenceInstall)
{
- IsIndeterminate = progress.IsIndeterminate;
- OneClickInstallProgress = Convert.ToInt32(progress.Percentage);
- EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress);
- });
-
- await SelectedPackage.DownloadPackage(installLocation, versionOptions, progress);
- SubHeaderText = Resources.Progress_DownloadComplete;
- OneClickInstallProgress = 100;
+ navigationService.NavigateTo();
+ }
}
- private async Task InstallPackage(string installLocation, TorchVersion torchVersion)
+ partial void OnShowIncompatiblePackagesChanged(bool value)
{
- var progress = new Progress(progress =>
- {
- SubHeaderText = Resources.Progress_InstallingPackageRequirements;
- IsIndeterminate = progress.IsIndeterminate;
- OneClickInstallProgress = Convert.ToInt32(progress.Percentage);
- EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress);
- });
+ var filteredPackages = packageFactory
+ .GetAllAvailablePackages()
+ .Where(p => p.OfferInOneClickInstaller && (ShowIncompatiblePackages || p.IsCompatible))
+ .ToList();
- await SelectedPackage.InstallPackage(
- installLocation,
- torchVersion,
- progress,
- (output) => SubSubHeaderText = output.Text
+ AllPackages = new ObservableCollection(
+ filteredPackages.Any() ? filteredPackages : packageFactory.GetAllAvailablePackages()
);
+ SelectedPackage = AllPackages[0];
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs
index 915c2a2f..06ff94de 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs
@@ -23,6 +23,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(PackageImportDialog))]
+[ManagedService]
+[Transient]
public partial class PackageImportViewModel : ContentDialogViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs
new file mode 100644
index 00000000..7259df54
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesItemViewModel.cs
@@ -0,0 +1,129 @@
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Avalonia.Controls;
+using CommunityToolkit.Mvvm.ComponentModel;
+using Semver;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Python;
+
+namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
+
+public partial class PythonPackagesItemViewModel : ViewModelBase
+{
+ [ObservableProperty]
+ private PipPackageInfo package;
+
+ [ObservableProperty]
+ private string? selectedVersion;
+
+ [ObservableProperty]
+ private IReadOnlyList? availableVersions;
+
+ [ObservableProperty]
+ private PipShowResult? pipShowResult;
+
+ [ObservableProperty]
+ private bool isLoading;
+
+ ///
+ /// True if selected version is newer than the installed version
+ ///
+ [ObservableProperty]
+ private bool canUpgrade;
+
+ ///
+ /// True if selected version is older than the installed version
+ ///
+ [ObservableProperty]
+ private bool canDowngrade;
+
+ partial void OnSelectedVersionChanged(string? value)
+ {
+ if (
+ value is null
+ || Package.Version == value
+ || !SemVersion.TryParse(Package.Version, out var currentSemver)
+ || !SemVersion.TryParse(value, out var selectedSemver)
+ )
+ {
+ CanUpgrade = false;
+ CanDowngrade = false;
+ return;
+ }
+
+ var precedence = selectedSemver.ComparePrecedenceTo(currentSemver);
+
+ CanUpgrade = precedence > 0;
+ CanDowngrade = precedence < 0;
+ }
+
+ ///
+ /// Return the known index URL for a package, currently this is torch, torchvision and torchaudio
+ ///
+ public static string? GetKnownIndexUrl(string packageName, string version)
+ {
+ var torchPackages = new[] { "torch", "torchvision", "torchaudio" };
+ if (torchPackages.Contains(packageName) && version.Contains('+'))
+ {
+ // Get the metadata for the current version (everything after the +)
+ var indexName = version.Split('+', 2).Last();
+
+ var indexUrl = $"https://download.pytorch.org/whl/{indexName}";
+ return indexUrl;
+ }
+
+ return null;
+ }
+
+ ///
+ /// Loads the pip show result if not already loaded
+ ///
+ public async Task LoadExtraInfo(DirectoryPath venvPath)
+ {
+ if (PipShowResult is not null)
+ {
+ return;
+ }
+
+ IsLoading = true;
+
+ try
+ {
+ if (Design.IsDesignMode)
+ {
+ await LoadExtraInfoDesignMode();
+ }
+ else
+ {
+ await using var venvRunner = new PyVenvRunner(venvPath);
+
+ PipShowResult = await venvRunner.PipShow(Package.Name);
+
+ // Attempt to get known index url
+ var indexUrl = GetKnownIndexUrl(Package.Name, Package.Version);
+
+ if (await venvRunner.PipIndex(Package.Name, indexUrl) is { } pipIndexResult)
+ {
+ AvailableVersions = pipIndexResult.AvailableVersions;
+ SelectedVersion = Package.Version;
+ }
+ }
+ }
+ finally
+ {
+ IsLoading = false;
+ }
+ }
+
+ private async Task LoadExtraInfoDesignMode()
+ {
+ await using var _ = new MinimumDelay(200, 300);
+
+ PipShowResult = new PipShowResult { Name = Package.Name, Version = Package.Version };
+ AvailableVersions = new[] { Package.Version, "1.2.0", "1.1.0", "1.0.0" };
+ SelectedVersion = Package.Version;
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesViewModel.cs
new file mode 100644
index 00000000..b7d00f8e
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PythonPackagesViewModel.cs
@@ -0,0 +1,313 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reactive.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using AsyncAwaitBestPractices;
+using Avalonia.Controls;
+using Avalonia.Controls.Primitives;
+using Avalonia.Threading;
+using CommunityToolkit.Mvvm.ComponentModel;
+using CommunityToolkit.Mvvm.Input;
+using DynamicData;
+using DynamicData.Binding;
+using FluentAvalonia.UI.Controls;
+using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Avalonia.Languages;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Avalonia.Views.Dialogs;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.PackageModification;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+
+namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
+
+[View(typeof(PythonPackagesDialog))]
+[ManagedService]
+[Transient]
+public partial class PythonPackagesViewModel : ContentDialogViewModelBase
+{
+ public DirectoryPath? VenvPath { get; set; }
+
+ [ObservableProperty]
+ private bool isLoading;
+
+ private readonly SourceCache packageSource = new(p => p.Name);
+
+ public IObservableCollection Packages { get; } =
+ new ObservableCollectionExtended();
+
+ [ObservableProperty]
+ private PythonPackagesItemViewModel? selectedPackage;
+
+ public PythonPackagesViewModel()
+ {
+ packageSource
+ .Connect()
+ .DeferUntilLoaded()
+ .Transform(p => new PythonPackagesItemViewModel { Package = p })
+ .SortBy(vm => vm.Package.Name)
+ .Bind(Packages)
+ .Subscribe();
+ }
+
+ private async Task Refresh()
+ {
+ if (VenvPath is null)
+ return;
+
+ IsLoading = true;
+
+ try
+ {
+ if (Design.IsDesignMode)
+ {
+ await Task.Delay(250);
+ }
+ else
+ {
+ await using var venvRunner = new PyVenvRunner(VenvPath);
+
+ var packages = await venvRunner.PipList();
+ packageSource.EditDiff(packages);
+ }
+ }
+ finally
+ {
+ IsLoading = false;
+ }
+ }
+
+ [RelayCommand]
+ private async Task RefreshBackground()
+ {
+ if (VenvPath is null)
+ return;
+
+ await using var venvRunner = new PyVenvRunner(VenvPath);
+
+ var packages = await venvRunner.PipList();
+
+ Dispatcher.UIThread.Post(() =>
+ {
+ // Backup selected package
+ var currentPackageName = SelectedPackage?.Package.Name;
+
+ packageSource.EditDiff(packages);
+
+ // Restore selected package
+ SelectedPackage = Packages.FirstOrDefault(p => p.Package.Name == currentPackageName);
+ });
+ }
+
+ ///
+ /// Load the selected package's show info if not already loaded
+ ///
+ partial void OnSelectedPackageChanged(PythonPackagesItemViewModel? value)
+ {
+ if (value is null)
+ {
+ return;
+ }
+
+ if (value.PipShowResult is null)
+ {
+ value.LoadExtraInfo(VenvPath!).SafeFireAndForget();
+ }
+ }
+
+ ///
+ public override Task OnLoadedAsync()
+ {
+ return Refresh();
+ }
+
+ public void AddPackages(params PipPackageInfo[] packages)
+ {
+ packageSource.AddOrUpdate(packages);
+ }
+
+ [RelayCommand]
+ private Task ModifySelectedPackage(PythonPackagesItemViewModel? item)
+ {
+ return item?.SelectedVersion != null
+ ? UpgradePackageVersion(
+ item.Package.Name,
+ item.SelectedVersion,
+ PythonPackagesItemViewModel.GetKnownIndexUrl(
+ item.Package.Name,
+ item.SelectedVersion
+ ),
+ isDowngrade: item.CanDowngrade
+ )
+ : Task.CompletedTask;
+ }
+
+ private async Task UpgradePackageVersion(
+ string packageName,
+ string version,
+ string? extraIndexUrl = null,
+ bool isDowngrade = false
+ )
+ {
+ if (VenvPath is null || SelectedPackage?.Package is not { } package)
+ return;
+
+ // Confirmation dialog
+ var dialog = DialogHelper.CreateMarkdownDialog(
+ isDowngrade
+ ? $"Downgrade **{package.Name}** to **{version}**?"
+ : $"Upgrade **{package.Name}** to **{version}**?",
+ Resources.Label_ConfirmQuestion
+ );
+
+ dialog.PrimaryButtonText = isDowngrade
+ ? Resources.Action_Downgrade
+ : Resources.Action_Upgrade;
+ dialog.IsPrimaryButtonEnabled = true;
+ dialog.DefaultButton = ContentDialogButton.Primary;
+ dialog.CloseButtonText = Resources.Action_Cancel;
+
+ if (await dialog.ShowAsync() is not ContentDialogResult.Primary)
+ {
+ return;
+ }
+
+ var args = new ProcessArgsBuilder("install", $"{packageName}=={version}");
+
+ if (extraIndexUrl != null)
+ {
+ args = args.AddArg(("--extra-index-url", extraIndexUrl));
+ }
+
+ var steps = new List
+ {
+ new PipStep
+ {
+ VenvDirectory = VenvPath,
+ WorkingDirectory = VenvPath.Parent,
+ Args = args
+ }
+ };
+
+ var runner = new PackageModificationRunner
+ {
+ ShowDialogOnStart = true,
+ ModificationCompleteMessage = isDowngrade
+ ? $"Downgraded Python Package '{packageName}' to {version}"
+ : $"Upgraded Python Package '{packageName}' to {version}"
+ };
+ EventManager.Instance.OnPackageInstallProgressAdded(runner);
+ await runner.ExecuteSteps(steps);
+
+ // Refresh
+ RefreshBackground().SafeFireAndForget();
+ }
+
+ [RelayCommand]
+ private async Task InstallPackage()
+ {
+ if (VenvPath is null)
+ return;
+
+ // Dialog
+ var fields = new TextBoxField[]
+ {
+ new() { Label = "Package Name", InnerLeftText = "pip install" }
+ };
+
+ var dialog = DialogHelper.CreateTextEntryDialog("Install Package", "", fields);
+ var result = await dialog.ShowAsync();
+
+ if (result != ContentDialogResult.Primary || fields[0].Text is not { } packageName)
+ {
+ return;
+ }
+
+ var steps = new List
+ {
+ new PipStep
+ {
+ VenvDirectory = VenvPath,
+ WorkingDirectory = VenvPath.Parent,
+ Args = new[] { "install", packageName }
+ }
+ };
+
+ var runner = new PackageModificationRunner
+ {
+ ShowDialogOnStart = true,
+ ModificationCompleteMessage = $"Installed Python Package '{packageName}'"
+ };
+ EventManager.Instance.OnPackageInstallProgressAdded(runner);
+ await runner.ExecuteSteps(steps);
+
+ // Refresh
+ RefreshBackground().SafeFireAndForget();
+ }
+
+ [RelayCommand]
+ private async Task UninstallSelectedPackage()
+ {
+ if (VenvPath is null || SelectedPackage?.Package is not { } package)
+ return;
+
+ // Confirmation dialog
+ var dialog = DialogHelper.CreateMarkdownDialog(
+ $"This will uninstall the package '{package.Name}'",
+ Resources.Label_ConfirmQuestion
+ );
+ dialog.PrimaryButtonText = Resources.Action_Uninstall;
+ dialog.IsPrimaryButtonEnabled = true;
+ dialog.DefaultButton = ContentDialogButton.Primary;
+ dialog.CloseButtonText = Resources.Action_Cancel;
+
+ if (await dialog.ShowAsync() is not ContentDialogResult.Primary)
+ {
+ return;
+ }
+
+ var steps = new List
+ {
+ new PipStep
+ {
+ VenvDirectory = VenvPath,
+ WorkingDirectory = VenvPath.Parent,
+ Args = new[] { "uninstall", "--yes", package.Name }
+ }
+ };
+
+ var runner = new PackageModificationRunner
+ {
+ ShowDialogOnStart = true,
+ ModificationCompleteMessage = $"Uninstalled Python Package '{package.Name}'"
+ };
+ EventManager.Instance.OnPackageInstallProgressAdded(runner);
+ await runner.ExecuteSteps(steps);
+
+ // Refresh
+ RefreshBackground().SafeFireAndForget();
+ }
+
+ public BetterContentDialog GetDialog()
+ {
+ return new BetterContentDialog
+ {
+ CloseOnClickOutside = true,
+ MinDialogWidth = 800,
+ MaxDialogWidth = 1500,
+ FullSizeDesired = true,
+ ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled,
+ Title = Resources.Label_PythonPackages,
+ Content = new PythonPackagesDialog { DataContext = this },
+ CloseButtonText = Resources.Action_Close,
+ DefaultButton = ContentDialogButton.Close
+ };
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs
index 15fa1e48..9c20bdf0 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs
@@ -19,6 +19,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(SelectDataDirectoryDialog))]
+[ManagedService]
+[Transient]
public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
index 77eac510..c3f10931 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
@@ -1,19 +1,19 @@
-using System;
-using System.Collections.Generic;
+using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
-using System.Threading.Tasks;
using Avalonia.Media.Imaging;
-using Avalonia.Platform;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
+[ManagedService]
+[Transient]
public partial class SelectModelVersionViewModel : ContentDialogViewModelBase
{
private readonly ISettingsManager settingsManager;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs
index 8710295a..8ee671a1 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs
@@ -22,6 +22,8 @@ using StabilityMatrix.Core.Updater;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(UpdateDialog))]
+[ManagedService]
+[Singleton]
public partial class UpdateViewModel : ContentDialogViewModelBase
{
private readonly ISettingsManager settingsManager;
diff --git a/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs
index 25a6faed..54463c21 100644
--- a/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/FirstLaunchSetupViewModel.cs
@@ -11,6 +11,8 @@ using StabilityMatrix.Core.Helper;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(FirstLaunchSetupWindow))]
+[ManagedService]
+[Singleton]
public partial class FirstLaunchSetupViewModel : ViewModelBase
{
[ObservableProperty]
@@ -20,14 +22,17 @@ public partial class FirstLaunchSetupViewModel : ViewModelBase
private string gpuInfoText = string.Empty;
[ObservableProperty]
- private RefreshBadgeViewModel checkHardwareBadge = new()
- {
- WorkingToolTipText = "We're checking some hardware specifications to determine compatibility.",
- SuccessToolTipText = "Everything looks good!",
- FailToolTipText = "We recommend a GPU with CUDA support for the best experience. " +
- "You can continue without one, but some packages may not work, and inference may be slower.",
- FailColorBrush = ThemeColors.ThemeYellow,
- };
+ private RefreshBadgeViewModel checkHardwareBadge =
+ new()
+ {
+ WorkingToolTipText =
+ "We're checking some hardware specifications to determine compatibility.",
+ SuccessToolTipText = "Everything looks good!",
+ FailToolTipText =
+ "We recommend a GPU with CUDA support for the best experience. "
+ + "You can continue without one, but some packages may not work, and inference may be slower.",
+ FailColorBrush = ThemeColors.ThemeYellow,
+ };
public FirstLaunchSetupViewModel()
{
@@ -43,14 +48,16 @@ public partial class FirstLaunchSetupViewModel : ViewModelBase
gpuInfo = await Task.Run(() => HardwareHelper.IterGpuInfo().ToArray());
}
// First Nvidia GPU
- var activeGpu = gpuInfo.FirstOrDefault(gpu => gpu.Name?.ToLowerInvariant().Contains("nvidia") ?? false);
+ var activeGpu = gpuInfo.FirstOrDefault(
+ gpu => gpu.Name?.ToLowerInvariant().Contains("nvidia") ?? false
+ );
var isNvidia = activeGpu is not null;
// Otherwise first GPU
activeGpu ??= gpuInfo.FirstOrDefault();
GpuInfoText = activeGpu is null
? "No GPU detected"
: $"{activeGpu.Name} ({Size.FormatBytes(activeGpu.MemoryBytes)})";
-
+
return isNvidia;
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs
index d18d22bd..12a1d0e9 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/BatchSizeCardViewModel.cs
@@ -6,6 +6,8 @@ using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(BatchSizeCard))]
+[ManagedService]
+[Transient]
public partial class BatchSizeCardViewModel : LoadableViewModelBase
{
[ObservableProperty]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs
index 059975d3..91b09866 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/FreeUCardViewModel.cs
@@ -7,6 +7,8 @@ using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(FreeUCard))]
+[ManagedService]
+[Transient]
public partial class FreeUCardViewModel : LoadableViewModelBase
{
[ObservableProperty]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs
index eab2a345..65f9fa0a 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/IImageGalleryComponent.cs
@@ -1,6 +1,4 @@
-using System.Collections.Generic;
-using System.Linq;
-using System.Threading.Tasks;
+using System.Linq;
using StabilityMatrix.Avalonia.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs
index e8097834..7502f496 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageFolderCardViewModel.cs
@@ -4,7 +4,6 @@ using System.Reactive.Linq;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using AsyncImageLoader;
-using Avalonia;
using Avalonia.Controls.Notifications;
using Avalonia.Platform.Storage;
using Avalonia.Threading;
@@ -26,6 +25,7 @@ using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;
using SortDirection = DynamicData.Binding.SortDirection;
@@ -33,6 +33,8 @@ using SortDirection = DynamicData.Binding.SortDirection;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ImageFolderCard))]
+[ManagedService]
+[Transient]
public partial class ImageFolderCardViewModel : ViewModelBase
{
private readonly ILogger logger;
@@ -43,6 +45,9 @@ public partial class ImageFolderCardViewModel : ViewModelBase
[ObservableProperty]
private string? searchQuery;
+ [ObservableProperty]
+ private Size imageSize = new(150, 190);
+
///
/// Collection of local image files
///
@@ -61,20 +66,28 @@ public partial class ImageFolderCardViewModel : ViewModelBase
this.settingsManager = settingsManager;
this.notificationService = notificationService;
- var predicate = this.WhenPropertyChanged(vm => vm.SearchQuery)
+ var searcher = new ImageSearcher();
+
+ // Observable predicate from SearchQuery changes
+ var searchPredicate = this.WhenPropertyChanged(vm => vm.SearchQuery)
.Throttle(TimeSpan.FromMilliseconds(50))!
- .Select, Func>(
- p => file => SearchPredicate(file, p.Value)
- )
+ .Select(property => searcher.GetPredicate(property.Value))
.AsObservable();
imageIndexService.InferenceImages.ItemsSource
.Connect()
.DeferUntilLoaded()
- .Filter(predicate)
+ .Filter(searchPredicate)
.SortBy(file => file.LastModifiedAt, SortDirection.Descending)
.Bind(LocalImages)
.Subscribe();
+
+ settingsManager.RelayPropertyFor(
+ this,
+ vm => vm.ImageSize,
+ settings => settings.InferenceImageSize,
+ delay: TimeSpan.FromMilliseconds(250)
+ );
}
private static bool SearchPredicate(LocalImageFile file, string? query)
@@ -116,24 +129,49 @@ public partial class ImageFolderCardViewModel : ViewModelBase
public override async Task OnLoadedAsync()
{
await base.OnLoadedAsync();
-
+ ImageSize = settingsManager.Settings.InferenceImageSize;
imageIndexService.RefreshIndexForAllCollections().SafeFireAndForget();
}
+ ///
+ /// Gets the image path if it exists, returns null.
+ /// If the image path is resolved but the file doesn't exist, it will be removed from the index.
+ ///
+ private FilePath? GetImagePathIfExists(LocalImageFile item)
+ {
+ var imageFile = new FilePath(item.AbsolutePath);
+
+ if (!imageFile.Exists)
+ {
+ // Remove from index
+ imageIndexService.InferenceImages.Remove(item);
+
+ // Invalidate cache
+ if (ImageLoader.AsyncImageLoader is FallbackRamCachedWebImageLoader loader)
+ {
+ loader.RemoveAllNamesFromCache(imageFile.Name);
+ }
+
+ return null;
+ }
+
+ return imageFile;
+ }
+
///
/// Handles image clicks to show preview
///
[RelayCommand]
private async Task OnImageClick(LocalImageFile item)
{
- if (item.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath)
+ if (GetImagePathIfExists(item) is not { } imageFile)
{
return;
}
var currentIndex = LocalImages.IndexOf(item);
- var image = new ImageSource(new FilePath(imagePath));
+ var image = new ImageSource(imageFile);
// Preload
await image.GetBitmapAsync();
@@ -156,14 +194,12 @@ public partial class ImageFolderCardViewModel : ViewModelBase
if (newIndex >= 0 && newIndex < LocalImages.Count)
{
var newImage = LocalImages[newIndex];
- var newImageSource = new ImageSource(
- new FilePath(newImage.GetFullPath(settingsManager.ImagesDirectory))
- );
+ var newImageSource = new ImageSource(newImage.AbsolutePath);
// Preload
await newImageSource.GetBitmapAsync();
- var oldImageSource = sender.ImageSource;
+ // var oldImageSource = sender.ImageSource;
sender.ImageSource = newImageSource;
sender.LocalImageFile = newImage;
@@ -185,13 +221,12 @@ public partial class ImageFolderCardViewModel : ViewModelBase
[RelayCommand]
private async Task OnImageDelete(LocalImageFile? item)
{
- if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath)
+ if (item is null || GetImagePathIfExists(item) is not { } imageFile)
{
return;
}
// Delete the file
- var imageFile = new FilePath(imagePath);
var result = await notificationService.TryAsync(imageFile.DeleteAsync());
if (!result.IsSuccessful)
@@ -215,14 +250,14 @@ public partial class ImageFolderCardViewModel : ViewModelBase
[RelayCommand]
private async Task OnImageCopy(LocalImageFile? item)
{
- if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath)
+ if (item is null || GetImagePathIfExists(item) is not { } imageFile)
{
return;
}
var clipboard = App.Clipboard;
- await clipboard.SetFileDataObjectAsync(imagePath);
+ await clipboard.SetFileDataObjectAsync(imageFile.FullPath);
}
///
@@ -231,12 +266,12 @@ public partial class ImageFolderCardViewModel : ViewModelBase
[RelayCommand]
private async Task OnImageOpen(LocalImageFile? item)
{
- if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } imagePath)
+ if (item is null || GetImagePathIfExists(item) is not { } imageFile)
{
return;
}
- await ProcessRunner.OpenFileBrowser(imagePath);
+ await ProcessRunner.OpenFileBrowser(imageFile);
}
///
@@ -248,13 +283,11 @@ public partial class ImageFolderCardViewModel : ViewModelBase
bool includeMetadata = false
)
{
- if (item?.GetFullPath(settingsManager.ImagesDirectory) is not { } sourcePath)
+ if (item is null || GetImagePathIfExists(item) is not { } sourceFile)
{
return;
}
- var sourceFile = new FilePath(sourcePath);
-
var formatName = format.ToString();
var storageFile = await App.StorageProvider.SaveFilePickerAsync(
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs
index a8f5920c..93497d43 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs
@@ -23,6 +23,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ImageGalleryCard))]
+[ManagedService]
+[Transient]
public partial class ImageGalleryCardViewModel : ViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
index 9e56a16d..d2a196e6 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageUpscaleViewModel.cs
@@ -1,13 +1,10 @@
using System;
using System.Diagnostics.CodeAnalysis;
-using System.Drawing;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
-using Avalonia.Controls.Shapes;
-using Avalonia.Threading;
using DynamicData.Binding;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
@@ -19,6 +16,7 @@ using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
+using StabilityMatrix.Core.Services;
using Path = System.IO.Path;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
@@ -27,6 +25,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceImageUpscaleView), persistent: true)]
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)]
+[ManagedService]
+[Transient]
public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -60,9 +60,10 @@ public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
public InferenceImageUpscaleViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
+ ISettingsManager settingsManager,
ServiceManager vmFactory
)
- : base(vmFactory, inferenceClientManager, notificationService)
+ : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{
this.notificationService = notificationService;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
index 07124ac0..c06c438c 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
@@ -26,6 +26,8 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), persistent: true)]
+[ManagedService]
+[Transient]
public class InferenceTextToImageViewModel
: InferenceGenerationViewModelBase,
IParametersLoadableState
@@ -86,10 +88,11 @@ public class InferenceTextToImageViewModel
public InferenceTextToImageViewModel(
INotificationService notificationService,
IInferenceClientManager inferenceClientManager,
+ ISettingsManager settingsManager,
ServiceManager vmFactory,
IModelIndexService modelIndexService
)
- : base(vmFactory, inferenceClientManager, notificationService)
+ : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
{
this.notificationService = notificationService;
this.modelIndexService = modelIndexService;
@@ -248,7 +251,7 @@ public class InferenceTextToImageViewModel
if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false })
{
var customVaeLoader = nodes.AddNamedNode(
- ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.FileName)
+ ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath)
);
builder.Connections.BaseVAE = customVaeLoader.Output;
@@ -381,22 +384,7 @@ public class InferenceTextToImageViewModel
Client = ClientManager.Client,
Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
- Parameters = new GenerationParameters
- {
- Seed = (ulong)seed,
- Steps = SamplerCardViewModel.Steps,
- CfgScale = SamplerCardViewModel.CfgScale,
- Sampler = SamplerCardViewModel.SelectedSampler?.Name,
- ModelName = ModelCardViewModel.SelectedModelName,
- ModelHash = ModelCardViewModel
- .SelectedModel
- ?.Local
- ?.ConnectedModelInfo
- ?.Hashes
- .SHA256,
- PositivePrompt = PromptCardViewModel.PromptDocument.Text,
- NegativePrompt = PromptCardViewModel.NegativePromptDocument.Text
- },
+ Parameters = SaveStateToParameters(new GenerationParameters()),
Project = InferenceProjectDocument.FromLoadable(this),
// Only clear output images on the first batch
ClearOutputImages = i == 0
@@ -417,10 +405,9 @@ public class InferenceTextToImageViewModel
{
PromptCardViewModel.LoadStateFromParameters(parameters);
SamplerCardViewModel.LoadStateFromParameters(parameters);
+ ModelCardViewModel.LoadStateFromParameters(parameters);
SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);
-
- ModelCardViewModel.LoadStateFromParameters(parameters);
}
///
@@ -428,11 +415,10 @@ public class InferenceTextToImageViewModel
{
parameters = PromptCardViewModel.SaveStateToParameters(parameters);
parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
+ parameters = ModelCardViewModel.SaveStateToParameters(parameters);
parameters.Seed = (ulong)SeedCardViewModel.Seed;
- parameters = ModelCardViewModel.SaveStateToParameters(parameters);
-
return parameters;
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
index 2c28da8f..d69f5b80 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
@@ -12,6 +12,8 @@ using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ModelCard))]
+[ManagedService]
+[Transient]
public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
[ObservableProperty]
@@ -29,9 +31,9 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
[ObservableProperty]
private bool isVaeSelectionEnabled;
- public string? SelectedModelName => SelectedModel?.FileName;
+ public string? SelectedModelName => SelectedModel?.RelativePath;
- public string? SelectedVaeName => SelectedVae?.FileName;
+ public string? SelectedVaeName => SelectedVae?.RelativePath;
public IInferenceClientManager ClientManager { get; }
@@ -60,11 +62,11 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
SelectedModel = model.SelectedModelName is null
? null
- : ClientManager.Models.FirstOrDefault(x => x.FileName == model.SelectedModelName);
+ : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName);
SelectedVae = model.SelectedVaeName is null
? HybridModelFile.Default
- : ClientManager.VaeModels.FirstOrDefault(x => x.FileName == model.SelectedVaeName);
+ : ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName);
}
internal class ModelCardModel
@@ -99,7 +101,7 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
else
{
// Name matches
- model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName));
+ model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName));
model ??= currentModels.FirstOrDefault(
m => m.ShortDisplayName.StartsWith(paramsModelName)
);
@@ -114,6 +116,10 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
///
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
- return parameters with { ModelName = SelectedModel?.FileName };
+ return parameters with
+ {
+ ModelName = SelectedModel?.FileName,
+ ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256
+ };
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
index a7b76ad0..26e8acf8 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
@@ -1,15 +1,9 @@
-using System;
-using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Diagnostics;
-using System.Linq;
-using System.Text;
+using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AvaloniaEdit;
using AvaloniaEdit.Document;
-using AvaloniaEdit.Editing;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Controls;
@@ -20,7 +14,6 @@ using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
-using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Services;
@@ -28,6 +21,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))]
+[ManagedService]
+[Transient]
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
private readonly IModelIndexService modelIndexService;
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
index 1d70fa42..8d52ebe9 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
@@ -13,6 +13,8 @@ using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SamplerCard))]
+[ManagedService]
+[Transient]
public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
[ObservableProperty]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs
index 6c52c4db..1f0a9fa4 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SeedCardViewModel.cs
@@ -10,6 +10,8 @@ using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SeedCard))]
+[ManagedService]
+[Transient]
public partial class SeedCardViewModel : LoadableViewModelBase
{
[ObservableProperty, NotifyPropertyChangedFor(nameof(RandomizeButtonToolTip))]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs
index ff175ccd..7639bc14 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs
@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Drawing;
-using System.Text.Json;
using System.Linq;
using Avalonia.Input;
using Avalonia.Media;
@@ -19,6 +18,8 @@ using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SelectImageCard))]
+[ManagedService]
+[Transient]
public partial class SelectImageCardViewModel : ViewModelBase, IDropTarget
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -88,7 +89,7 @@ public partial class SelectImageCardViewModel : ViewModelBase, IDropTarget
{
var current = ImageSource;
- ImageSource = new ImageSource(imageFile.GlobalFullPath);
+ ImageSource = new ImageSource(imageFile.AbsolutePath);
current?.Dispose();
});
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs
index 98a1f933..90a67a12 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SharpenCardViewModel.cs
@@ -7,6 +7,8 @@ using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SharpenCard))]
+[ManagedService]
+[Transient]
public partial class SharpenCardViewModel : LoadableViewModelBase
{
[Range(1, 31)]
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs
index 763054b2..49e7b810 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/StackCardViewModel.cs
@@ -8,20 +8,24 @@ using StabilityMatrix.Core.Extensions;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(StackCard))]
+[ManagedService]
+[Transient]
public class StackCardViewModel : StackViewModelBase
{
///
public override void LoadStateFromJsonObject(JsonObject state)
{
var model = DeserializeModel(state);
-
- if (model.Cards is null) return;
-
+
+ if (model.Cards is null)
+ return;
+
foreach (var (i, card) in model.Cards.Enumerate())
{
// Ignore if more than cards than we have
- if (i > Cards.Count - 1) break;
-
+ if (i > Cards.Count - 1)
+ break;
+
Cards[i].LoadStateFromJsonObject(card);
}
}
@@ -29,9 +33,8 @@ public class StackCardViewModel : StackViewModelBase
///
public override JsonObject SaveStateToJsonObject()
{
- return SerializeModel(new StackCardModel
- {
- Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList()
- });
+ return SerializeModel(
+ new StackCardModel { Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList() }
+ );
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs
index 2b4e3c2e..9eef7ad6 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs
@@ -11,28 +11,32 @@ using StabilityMatrix.Core.Extensions;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(StackExpander))]
+[ManagedService]
+[Transient]
public partial class StackExpanderViewModel : StackViewModelBase
{
[ObservableProperty]
[property: JsonIgnore]
private string? title;
-
- [ObservableProperty]
+
+ [ObservableProperty]
private bool isEnabled;
-
+
///
public override void LoadStateFromJsonObject(JsonObject state)
{
var model = DeserializeModel(state);
IsEnabled = model.IsEnabled;
-
- if (model.Cards is null) return;
-
+
+ if (model.Cards is null)
+ return;
+
foreach (var (i, card) in model.Cards.Enumerate())
{
// Ignore if more than cards than we have
- if (i > Cards.Count - 1) break;
-
+ if (i > Cards.Count - 1)
+ break;
+
Cards[i].LoadStateFromJsonObject(card);
}
}
@@ -40,10 +44,12 @@ public partial class StackExpanderViewModel : StackViewModelBase
///
public override JsonObject SaveStateToJsonObject()
{
- return SerializeModel(new StackExpanderModel
- {
- IsEnabled = IsEnabled,
- Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList()
- });
+ return SerializeModel(
+ new StackExpanderModel
+ {
+ IsEnabled = IsEnabled,
+ Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList()
+ }
+ );
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs
index ceef75c6..02ef8a83 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Inference/UpscalerCardViewModel.cs
@@ -20,6 +20,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(UpscalerCard))]
+[ManagedService]
+[Transient]
public partial class UpscalerCardViewModel : LoadableViewModelBase
{
private readonly INotificationService notificationService;
diff --git a/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
index 4b2ff874..33320aab 100644
--- a/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/InferenceViewModel.cs
@@ -9,7 +9,6 @@ using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
-using Avalonia.Controls.Shapes;
using Avalonia.Platform.Storage;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
@@ -23,7 +22,6 @@ using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Avalonia.Views;
-using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions;
@@ -43,6 +41,7 @@ namespace StabilityMatrix.Avalonia.ViewModels;
[Preload]
[View(typeof(InferencePage))]
+[Singleton]
public partial class InferenceViewModel : PageViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -111,6 +110,10 @@ public partial class InferenceViewModel : PageViewModelBase
// Keep RunningPackage updated with the current package pair
EventManager.Instance.RunningPackageStatusChanged += OnRunningPackageStatusChanged;
+ // "Send to Inference"
+ EventManager.Instance.InferenceTextToImageRequested += OnInferenceTextToImageRequested;
+ EventManager.Instance.InferenceUpscaleRequested += OnInferenceUpscaleRequested;
+
MenuSaveAsCommand.WithConditionalNotificationErrorHandler(notificationService);
MenuOpenProjectCommand.WithConditionalNotificationErrorHandler(notificationService);
}
@@ -232,6 +235,16 @@ public partial class InferenceViewModel : PageViewModelBase
}
}
+ private void OnInferenceTextToImageRequested(object? sender, LocalImageFile e)
+ {
+ Dispatcher.UIThread.Post(() => AddTabFromImage(e).SafeFireAndForget());
+ }
+
+ private void OnInferenceUpscaleRequested(object? sender, LocalImageFile e)
+ {
+ Dispatcher.UIThread.Post(() => AddUpscalerTabFromImage(e).SafeFireAndForget());
+ }
+
///
/// Update the database with current tabs
///
@@ -593,6 +606,70 @@ public partial class InferenceViewModel : PageViewModelBase
await SyncTabStatesWithDatabase();
}
+ private async Task AddTabFromImage(LocalImageFile imageFile)
+ {
+ var metadata = imageFile.ReadMetadata();
+ InferenceTabViewModelBase? vm = null;
+
+ if (!string.IsNullOrWhiteSpace(metadata.SMProject))
+ {
+ var document = JsonSerializer.Deserialize(metadata.SMProject);
+ if (document is null)
+ {
+ throw new ApplicationException(
+ "MenuOpenProject: Deserialize project file returned null"
+ );
+ }
+
+ if (document.State is null)
+ {
+ throw new ApplicationException("Project file does not have 'State' key");
+ }
+
+ document.VerifyVersion();
+ var textToImage = vmFactory.Get();
+ textToImage.LoadStateFromJsonObject(document.State);
+ vm = textToImage;
+ }
+ else if (!string.IsNullOrWhiteSpace(metadata.Parameters))
+ {
+ if (GenerationParameters.TryParse(metadata.Parameters, out var generationParameters))
+ {
+ var textToImageViewModel = vmFactory.Get();
+ textToImageViewModel.LoadStateFromParameters(generationParameters);
+ vm = textToImageViewModel;
+ }
+ }
+
+ if (vm == null)
+ {
+ notificationService.Show(
+ "Unable to load project from image",
+ "No image metadata found",
+ NotificationType.Error
+ );
+ return;
+ }
+
+ Tabs.Add(vm);
+
+ SelectedTab = vm;
+
+ await SyncTabStatesWithDatabase();
+ }
+
+ private async Task AddUpscalerTabFromImage(LocalImageFile imageFile)
+ {
+ var upscaleVm = vmFactory.Get();
+ upscaleVm.IsUpscaleEnabled = true;
+ upscaleVm.SelectImageCardViewModel.ImageSource = new ImageSource(imageFile.AbsolutePath);
+
+ Tabs.Add(upscaleVm);
+ SelectedTab = upscaleVm;
+
+ await SyncTabStatesWithDatabase();
+ }
+
///
/// Menu "Open Project" command.
///
diff --git a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
index 87c7ea3f..b9ed79ab 100644
--- a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
@@ -40,6 +40,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(LaunchPageView))]
+[Singleton]
public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyncDisposable
{
private readonly ILogger logger;
diff --git a/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs
index b165aef0..87ebb0f5 100644
--- a/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs
@@ -2,14 +2,12 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
-using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
-using AvaloniaEdit.Utils;
using CommunityToolkit.Mvvm.ComponentModel;
using FluentAvalonia.UI.Controls;
using Microsoft.Extensions.Logging;
@@ -35,6 +33,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(NewCheckpointsPage))]
+[Singleton]
public partial class NewCheckpointsPageViewModel : PageViewModelBase
{
private readonly ILogger logger;
@@ -44,12 +43,17 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
private readonly ServiceManager dialogFactory;
private readonly INotificationService notificationService;
public override string Title => "Checkpoint Manager";
- public override IconSource IconSource => new SymbolIconSource
- {Symbol = Symbol.Cellular5g, IsFilled = true};
+ public override IconSource IconSource =>
+ new SymbolIconSource { Symbol = Symbol.Cellular5g, IsFilled = true };
- public NewCheckpointsPageViewModel(ILogger logger,
- ISettingsManager settingsManager, ILiteDbContext liteDbContext, ICivitApi civitApi,
- ServiceManager dialogFactory, INotificationService notificationService)
+ public NewCheckpointsPageViewModel(
+ ILogger logger,
+ ISettingsManager settingsManager,
+ ILiteDbContext liteDbContext,
+ ICivitApi civitApi,
+ ServiceManager dialogFactory,
+ INotificationService notificationService
+ )
{
this.logger = logger;
this.settingsManager = settingsManager;
@@ -63,23 +67,27 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
[NotifyPropertyChangedFor(nameof(ConnectedCheckpoints))]
[NotifyPropertyChangedFor(nameof(NonConnectedCheckpoints))]
private ObservableCollection allCheckpoints = new();
-
+
[ObservableProperty]
private ObservableCollection civitModels = new();
- public ObservableCollection ConnectedCheckpoints => new(
- AllCheckpoints.Where(x => x.IsConnectedModel)
- .OrderBy(x => x.ConnectedModel!.ModelName)
- .ThenBy(x => x.ModelType)
- .GroupBy(x => x.ConnectedModel!.ModelId)
- .Select(x => x.First()));
+ public ObservableCollection ConnectedCheckpoints =>
+ new(
+ AllCheckpoints
+ .Where(x => x.IsConnectedModel)
+ .OrderBy(x => x.ConnectedModel!.ModelName)
+ .ThenBy(x => x.ModelType)
+ .GroupBy(x => x.ConnectedModel!.ModelId)
+ .Select(x => x.First())
+ );
- public ObservableCollection NonConnectedCheckpoints => new(
- AllCheckpoints.Where(x => !x.IsConnectedModel).OrderBy(x => x.ModelType));
+ public ObservableCollection NonConnectedCheckpoints =>
+ new(AllCheckpoints.Where(x => !x.IsConnectedModel).OrderBy(x => x.ModelType));
public override async Task OnLoadedAsync()
{
- if (Design.IsDesignMode) return;
+ if (Design.IsDesignMode)
+ return;
var files = CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory);
AllCheckpoints = new ObservableCollection(files);
@@ -89,17 +97,17 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
{
CommaSeparatedModelIds = string.Join(',', connectedModelIds)
};
-
+
// See if query is cached
var cachedQuery = await liteDbContext.CivitModelQueryCache
.IncludeAll()
.FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest));
-
+
// If cached, update model cards
if (cachedQuery is not null)
{
CivitModels = new ObservableCollection(cachedQuery.Items);
-
+
// Start remote query (background mode)
// Skip when last query was less than 2 min ago
var timeSinceCache = DateTimeOffset.UtcNow - cachedQuery.InsertedAt;
@@ -113,24 +121,34 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
await CivitQuery(modelRequest);
}
}
-
+
public async Task ShowVersionDialog(int modelId)
{
var model = CivitModels.FirstOrDefault(m => m.Id == modelId);
if (model == null)
{
- notificationService.Show(new Notification("Model has no versions available",
- "This model has no versions available for download", NotificationType.Warning));
+ notificationService.Show(
+ new Notification(
+ "Model has no versions available",
+ "This model has no versions available for download",
+ NotificationType.Warning
+ )
+ );
return;
}
var versions = model.ModelVersions;
if (versions is null || versions.Count == 0)
{
- notificationService.Show(new Notification("Model has no versions available",
- "This model has no versions available for download", NotificationType.Warning));
+ notificationService.Show(
+ new Notification(
+ "Model has no versions available",
+ "This model has no versions available for download",
+ NotificationType.Warning
+ )
+ );
return;
}
-
+
var dialog = new BetterContentDialog
{
Title = model.Name,
@@ -139,19 +157,21 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
IsFooterVisible = false,
MaxDialogWidth = 750,
};
-
+
var viewModel = dialogFactory.Get();
viewModel.Dialog = dialog;
- viewModel.Versions = versions.Select(version =>
- new ModelVersionViewModel(
- settingsManager.Settings.InstalledModelHashes ?? new HashSet(), version))
+ viewModel.Versions = versions
+ .Select(
+ version =>
+ new ModelVersionViewModel(
+ settingsManager.Settings.InstalledModelHashes ?? new HashSet(),
+ version
+ )
+ )
.ToImmutableArray();
viewModel.SelectedVersionViewModel = viewModel.Versions[0];
-
- dialog.Content = new SelectModelVersionDialog
- {
- DataContext = viewModel
- };
+
+ dialog.Content = new SelectModelVersionDialog { DataContext = viewModel };
var result = await dialog.ShowAsync();
@@ -171,8 +191,10 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
var modelResponse = await civitApi.GetModels(request);
var models = modelResponse.Items;
// Filter out unknown model types and archived/taken-down models
- models = models.Where(m => m.Type.ConvertTo() > 0)
- .Where(m => m.Mode == null).ToList();
+ models = models
+ .Where(m => m.Type.ConvertTo() > 0)
+ .Where(m => m.Mode == null)
+ .ToList();
// Database update calls will invoke `OnModelsUpdated`
// Add to database
@@ -186,7 +208,8 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
Request = request,
Items = models,
Metadata = modelResponse.Metadata
- });
+ }
+ );
if (cacheNew)
{
@@ -195,26 +218,42 @@ public partial class NewCheckpointsPageViewModel : PageViewModelBase
}
catch (OperationCanceledException)
{
- notificationService.Show(new Notification("Request to CivitAI timed out",
- "Could not check for checkpoint updates. Please try again later."));
+ notificationService.Show(
+ new Notification(
+ "Request to CivitAI timed out",
+ "Could not check for checkpoint updates. Please try again later."
+ )
+ );
logger.LogWarning($"CivitAI query timed out ({request})");
}
catch (HttpRequestException e)
{
- notificationService.Show(new Notification("CivitAI can't be reached right now",
- "Could not check for checkpoint updates. Please try again later."));
+ notificationService.Show(
+ new Notification(
+ "CivitAI can't be reached right now",
+ "Could not check for checkpoint updates. Please try again later."
+ )
+ );
logger.LogWarning(e, $"CivitAI query HttpRequestException ({request})");
}
catch (ApiException e)
{
- notificationService.Show(new Notification("CivitAI can't be reached right now",
- "Could not check for checkpoint updates. Please try again later."));
+ notificationService.Show(
+ new Notification(
+ "CivitAI can't be reached right now",
+ "Could not check for checkpoint updates. Please try again later."
+ )
+ );
logger.LogWarning(e, $"CivitAI query ApiException ({request})");
}
catch (Exception e)
{
- notificationService.Show(new Notification("CivitAI can't be reached right now",
- $"Unknown exception during CivitAI query: {e.GetType().Name}"));
+ notificationService.Show(
+ new Notification(
+ "CivitAI can't be reached right now",
+ $"Unknown exception during CivitAI query: {e.GetType().Name}"
+ )
+ );
logger.LogError(e, $"CivitAI query unknown exception ({request})");
}
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs
new file mode 100644
index 00000000..30ce380e
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs
@@ -0,0 +1,18 @@
+using CommunityToolkit.Mvvm.ComponentModel;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Models.Database;
+
+namespace StabilityMatrix.Avalonia.ViewModels.OutputsPage;
+
+public partial class OutputImageViewModel : ViewModelBase
+{
+ public LocalImageFile ImageFile { get; }
+
+ [ObservableProperty]
+ private bool isSelected;
+
+ public OutputImageViewModel(LocalImageFile imageFile)
+ {
+ ImageFile = imageFile;
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs
new file mode 100644
index 00000000..8352f6d5
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs
@@ -0,0 +1,531 @@
+using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using System.Reactive.Linq;
+using System.Threading.Tasks;
+using AsyncAwaitBestPractices;
+using AsyncImageLoader;
+using Avalonia;
+using Avalonia.Controls;
+using Avalonia.Media;
+using Avalonia.Threading;
+using CommunityToolkit.Mvvm.ComponentModel;
+using DynamicData;
+using DynamicData.Binding;
+using FluentAvalonia.UI.Controls;
+using Microsoft.Extensions.Logging;
+using Nito.Disposables.Internals;
+using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Avalonia.Extensions;
+using StabilityMatrix.Avalonia.Helpers;
+using StabilityMatrix.Avalonia.Languages;
+using StabilityMatrix.Avalonia.Models;
+using StabilityMatrix.Avalonia.Services;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Avalonia.ViewModels.Dialogs;
+using StabilityMatrix.Avalonia.ViewModels.OutputsPage;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.Factory;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.Database;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Services;
+using Size = StabilityMatrix.Core.Models.Settings.Size;
+using Symbol = FluentIcons.Common.Symbol;
+using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
+
+namespace StabilityMatrix.Avalonia.ViewModels;
+
+[View(typeof(Views.OutputsPage))]
+[Singleton]
+public partial class OutputsPageViewModel : PageViewModelBase
+{
+ private readonly ISettingsManager settingsManager;
+ private readonly IPackageFactory packageFactory;
+ private readonly INotificationService notificationService;
+ private readonly INavigationService navigationService;
+ private readonly ILogger logger;
+ public override string Title => Resources.Label_OutputsPageTitle;
+
+ public override IconSource IconSource =>
+ new SymbolIconSource { Symbol = Symbol.Grid, IsFilled = true };
+
+ public SourceCache OutputsCache { get; } =
+ new(file => file.AbsolutePath);
+
+ public IObservableCollection Outputs { get; set; } =
+ new ObservableCollectionExtended();
+
+ public IEnumerable OutputTypes { get; } = Enum.GetValues();
+
+ [ObservableProperty]
+ private ObservableCollection categories;
+
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(CanShowOutputTypes))]
+ private PackageOutputCategory selectedCategory;
+
+ [ObservableProperty]
+ private SharedOutputType selectedOutputType;
+
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(NumImagesSelected))]
+ private int numItemsSelected;
+
+ [ObservableProperty]
+ private string searchQuery;
+
+ [ObservableProperty]
+ private Size imageSize = new(300, 300);
+
+ [ObservableProperty]
+ private bool isConsolidating;
+
+ public bool CanShowOutputTypes =>
+ SelectedCategory?.Name?.Equals("Shared Output Folder") ?? false;
+
+ public string NumImagesSelected =>
+ NumItemsSelected == 1
+ ? Resources.Label_OneImageSelected
+ : string.Format(Resources.Label_NumImagesSelected, NumItemsSelected);
+
+ public OutputsPageViewModel(
+ ISettingsManager settingsManager,
+ IPackageFactory packageFactory,
+ INotificationService notificationService,
+ INavigationService navigationService,
+ ILogger logger
+ )
+ {
+ this.settingsManager = settingsManager;
+ this.packageFactory = packageFactory;
+ this.notificationService = notificationService;
+ this.navigationService = navigationService;
+ this.logger = logger;
+
+ var searcher = new ImageSearcher();
+
+ // Observable predicate from SearchQuery changes
+ var searchPredicate = this.WhenPropertyChanged(vm => vm.SearchQuery)
+ .Throttle(TimeSpan.FromMilliseconds(50))!
+ .Select(property => searcher.GetPredicate(property.Value))
+ .AsObservable();
+
+ OutputsCache
+ .Connect()
+ .DeferUntilLoaded()
+ .Filter(searchPredicate)
+ .Transform(file => new OutputImageViewModel(file))
+ .SortBy(vm => vm.ImageFile.CreatedAt, SortDirection.Descending)
+ .Bind(Outputs)
+ .WhenPropertyChanged(p => p.IsSelected)
+ .Subscribe(_ =>
+ {
+ NumItemsSelected = Outputs.Count(o => o.IsSelected);
+ });
+
+ settingsManager.RelayPropertyFor(
+ this,
+ vm => vm.ImageSize,
+ settings => settings.OutputsImageSize,
+ delay: TimeSpan.FromMilliseconds(250)
+ );
+ }
+
+ public override void OnLoaded()
+ {
+ if (Design.IsDesignMode)
+ return;
+
+ if (!settingsManager.IsLibraryDirSet)
+ return;
+
+ Directory.CreateDirectory(settingsManager.ImagesDirectory);
+
+ var packageCategories = settingsManager.Settings.InstalledPackages
+ .Where(x => !x.UseSharedOutputFolder)
+ .Select(packageFactory.GetPackagePair)
+ .WhereNotNull()
+ .Where(
+ p =>
+ p.BasePackage.SharedOutputFolders != null
+ && p.BasePackage.SharedOutputFolders.Any()
+ )
+ .Select(
+ pair =>
+ new PackageOutputCategory
+ {
+ Path = Path.Combine(
+ pair.InstalledPackage.FullPath!,
+ pair.BasePackage.OutputFolderName
+ ),
+ Name = pair.InstalledPackage.DisplayName ?? ""
+ }
+ )
+ .ToList();
+
+ packageCategories.Insert(
+ 0,
+ new PackageOutputCategory
+ {
+ Path = settingsManager.ImagesDirectory,
+ Name = "Shared Output Folder"
+ }
+ );
+
+ packageCategories.Insert(
+ 1,
+ new PackageOutputCategory
+ {
+ Path = settingsManager.ImagesInferenceDirectory,
+ Name = "Inference"
+ }
+ );
+
+ Categories = new ObservableCollection(packageCategories);
+ SelectedCategory = Categories.First();
+ SelectedOutputType = SharedOutputType.All;
+ SearchQuery = string.Empty;
+ ImageSize = settingsManager.Settings.OutputsImageSize;
+
+ var path =
+ CanShowOutputTypes && SelectedOutputType != SharedOutputType.All
+ ? Path.Combine(SelectedCategory.Path, SelectedOutputType.ToString())
+ : SelectedCategory.Path;
+ GetOutputs(path);
+ }
+
+ partial void OnSelectedCategoryChanged(
+ PackageOutputCategory? oldValue,
+ PackageOutputCategory? newValue
+ )
+ {
+ if (oldValue == newValue || newValue == null)
+ return;
+
+ var path =
+ CanShowOutputTypes && SelectedOutputType != SharedOutputType.All
+ ? Path.Combine(newValue.Path, SelectedOutputType.ToString())
+ : SelectedCategory.Path;
+ GetOutputs(path);
+ }
+
+ partial void OnSelectedOutputTypeChanged(SharedOutputType oldValue, SharedOutputType newValue)
+ {
+ if (oldValue == newValue)
+ return;
+
+ var path =
+ newValue == SharedOutputType.All
+ ? SelectedCategory.Path
+ : Path.Combine(SelectedCategory.Path, newValue.ToString());
+ GetOutputs(path);
+ }
+
+ public Task OnImageClick(OutputImageViewModel item)
+ {
+ // Select image if we're in "select mode"
+ if (NumItemsSelected > 0)
+ {
+ item.IsSelected = !item.IsSelected;
+ }
+ else
+ {
+ return ShowImageDialog(item);
+ }
+
+ return Task.CompletedTask;
+ }
+
+ public async Task ShowImageDialog(OutputImageViewModel item)
+ {
+ var currentIndex = Outputs.IndexOf(item);
+
+ var image = new ImageSource(new FilePath(item.ImageFile.AbsolutePath));
+
+ // Preload
+ await image.GetBitmapAsync();
+
+ var vm = new ImageViewerViewModel { ImageSource = image, LocalImageFile = item.ImageFile };
+
+ using var onNext = Observable
+ .FromEventPattern(
+ vm,
+ nameof(ImageViewerViewModel.NavigationRequested)
+ )
+ .Subscribe(ctx =>
+ {
+ Dispatcher.UIThread
+ .InvokeAsync(async () =>
+ {
+ var sender = (ImageViewerViewModel)ctx.Sender!;
+ var newIndex = currentIndex + (ctx.EventArgs.IsNext ? 1 : -1);
+
+ if (newIndex >= 0 && newIndex < Outputs.Count)
+ {
+ var newImage = Outputs[newIndex];
+ var newImageSource = new ImageSource(
+ new FilePath(newImage.ImageFile.AbsolutePath)
+ );
+
+ // Preload
+ await newImageSource.GetBitmapAsync();
+
+ sender.ImageSource = newImageSource;
+ sender.LocalImageFile = newImage.ImageFile;
+
+ currentIndex = newIndex;
+ }
+ })
+ .SafeFireAndForget();
+ });
+
+ await vm.GetDialog().ShowAsync();
+ }
+
+ public Task CopyImage(string imagePath)
+ {
+ var clipboard = App.Clipboard;
+ return clipboard.SetFileDataObjectAsync(imagePath);
+ }
+
+ public Task OpenImage(string imagePath) => ProcessRunner.OpenFileBrowser(imagePath);
+
+ public async Task DeleteImage(OutputImageViewModel? item)
+ {
+ if (item is null)
+ return;
+
+ var confirmationDialog = new BetterContentDialog
+ {
+ Title = "Are you sure you want to delete this image?",
+ Content = "This action cannot be undone.",
+ PrimaryButtonText = Resources.Action_Delete,
+ SecondaryButtonText = Resources.Action_Cancel,
+ DefaultButton = ContentDialogButton.Primary,
+ IsSecondaryButtonEnabled = true,
+ };
+ var dialogResult = await confirmationDialog.ShowAsync();
+ if (dialogResult != ContentDialogResult.Primary)
+ return;
+
+ // Delete the file
+ var imageFile = new FilePath(item.ImageFile.AbsolutePath);
+ var result = await notificationService.TryAsync(imageFile.DeleteAsync());
+
+ if (!result.IsSuccessful)
+ {
+ return;
+ }
+
+ OutputsCache.Remove(item.ImageFile);
+
+ // Invalidate cache
+ if (ImageLoader.AsyncImageLoader is FallbackRamCachedWebImageLoader loader)
+ {
+ loader.RemoveAllNamesFromCache(imageFile.Name);
+ }
+ }
+
+ public void SendToTextToImage(OutputImageViewModel vm)
+ {
+ navigationService.NavigateTo();
+ EventManager.Instance.OnInferenceTextToImageRequested(vm.ImageFile);
+ }
+
+ public void SendToUpscale(OutputImageViewModel vm)
+ {
+ navigationService.NavigateTo();
+ EventManager.Instance.OnInferenceUpscaleRequested(vm.ImageFile);
+ }
+
+ public void ClearSelection()
+ {
+ foreach (var output in Outputs)
+ {
+ output.IsSelected = false;
+ }
+ }
+
+ public void SelectAll()
+ {
+ foreach (var output in Outputs)
+ {
+ output.IsSelected = true;
+ }
+ }
+
+ public async Task DeleteAllSelected()
+ {
+ var confirmationDialog = new BetterContentDialog
+ {
+ Title = $"Are you sure you want to delete {NumItemsSelected} images?",
+ Content = "This action cannot be undone.",
+ PrimaryButtonText = Resources.Action_Delete,
+ SecondaryButtonText = Resources.Action_Cancel,
+ DefaultButton = ContentDialogButton.Primary,
+ IsSecondaryButtonEnabled = true,
+ };
+ var dialogResult = await confirmationDialog.ShowAsync();
+ if (dialogResult != ContentDialogResult.Primary)
+ return;
+
+ var selected = Outputs.Where(o => o.IsSelected).ToList();
+ Debug.Assert(selected.Count == NumItemsSelected);
+ foreach (var output in selected)
+ {
+ // Delete the file
+ var imageFile = new FilePath(output.ImageFile.AbsolutePath);
+ var result = await notificationService.TryAsync(imageFile.DeleteAsync());
+
+ if (!result.IsSuccessful)
+ {
+ continue;
+ }
+ OutputsCache.Remove(output.ImageFile);
+
+ // Invalidate cache
+ if (ImageLoader.AsyncImageLoader is FallbackRamCachedWebImageLoader loader)
+ {
+ loader.RemoveAllNamesFromCache(imageFile.Name);
+ }
+ }
+
+ NumItemsSelected = 0;
+ ClearSelection();
+ }
+
+ public async Task ConsolidateImages()
+ {
+ var stackPanel = new StackPanel();
+ stackPanel.Children.Add(
+ new TextBlock
+ {
+ Text = Resources.Label_ConsolidateExplanation,
+ TextWrapping = TextWrapping.Wrap,
+ Margin = new Thickness(0, 8, 0, 16)
+ }
+ );
+ foreach (var category in Categories)
+ {
+ if (category.Name == "Shared Output Folder")
+ {
+ continue;
+ }
+
+ stackPanel.Children.Add(
+ new CheckBox
+ {
+ Content = $"{category.Name} ({category.Path})",
+ IsChecked = true,
+ Margin = new Thickness(0, 8, 0, 0),
+ Tag = category.Path
+ }
+ );
+ }
+
+ var confirmationDialog = new BetterContentDialog
+ {
+ Title = Resources.Label_AreYouSure,
+ Content = stackPanel,
+ PrimaryButtonText = Resources.Action_Yes,
+ SecondaryButtonText = Resources.Action_Cancel,
+ DefaultButton = ContentDialogButton.Primary,
+ IsSecondaryButtonEnabled = true,
+ };
+
+ var dialogResult = await confirmationDialog.ShowAsync();
+ if (dialogResult != ContentDialogResult.Primary)
+ return;
+
+ IsConsolidating = true;
+
+ Directory.CreateDirectory(settingsManager.ConsolidatedImagesDirectory);
+
+ foreach (
+ var category in stackPanel.Children.OfType().Where(c => c.IsChecked == true)
+ )
+ {
+ if (
+ string.IsNullOrWhiteSpace(category.Tag?.ToString())
+ || !Directory.Exists(category.Tag?.ToString())
+ )
+ continue;
+
+ var directory = category.Tag.ToString();
+
+ foreach (
+ var path in Directory.EnumerateFiles(
+ directory,
+ "*.png",
+ SearchOption.AllDirectories
+ )
+ )
+ {
+ try
+ {
+ var file = new FilePath(path);
+ var newPath = settingsManager.ConsolidatedImagesDirectory + file.Name;
+ if (file.FullPath == newPath)
+ continue;
+
+ // ignore inference if not in inference directory
+ if (
+ file.FullPath.Contains(settingsManager.ImagesInferenceDirectory)
+ && directory != settingsManager.ImagesInferenceDirectory
+ )
+ {
+ continue;
+ }
+
+ await file.MoveToAsync(newPath);
+ }
+ catch (Exception e)
+ {
+ logger.LogError(e, "Error when consolidating: ");
+ }
+ }
+ }
+
+ OnLoaded();
+ IsConsolidating = false;
+ }
+
+ private void GetOutputs(string directory)
+ {
+ if (!settingsManager.IsLibraryDirSet)
+ return;
+
+ if (
+ !Directory.Exists(directory)
+ && (
+ SelectedCategory.Path != settingsManager.ImagesDirectory
+ || SelectedOutputType != SharedOutputType.All
+ )
+ )
+ {
+ Directory.CreateDirectory(directory);
+ return;
+ }
+
+ var files = Directory
+ .EnumerateFiles(directory, "*.png", SearchOption.AllDirectories)
+ .Select(file => LocalImageFile.FromPath(file))
+ .ToList();
+
+ if (files.Count == 0)
+ {
+ OutputsCache.Clear();
+ }
+ else
+ {
+ OutputsCache.EditDiff(files);
+ }
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
index d921aa08..467f45d7 100644
--- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
@@ -5,6 +5,7 @@ using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
using CommunityToolkit.Mvvm.ComponentModel;
+using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Avalonia.Animations;
@@ -14,6 +15,7 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.Views.Dialogs;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory;
@@ -26,6 +28,8 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.PackageManager;
+[ManagedService]
+[Transient]
public partial class PackageCardViewModel : ProgressViewModel
{
private readonly ILogger logger;
@@ -62,6 +66,15 @@ public partial class PackageCardViewModel : ProgressViewModel
[ObservableProperty]
private bool canUseConfigMethod;
+ [ObservableProperty]
+ private bool canUseSymlinkMethod;
+
+ [ObservableProperty]
+ private bool useSharedOutput;
+
+ [ObservableProperty]
+ private bool canUseSharedOutput;
+
public PackageCardViewModel(
ILogger logger,
IPackageFactory packageFactory,
@@ -103,6 +116,11 @@ public partial class PackageCardViewModel : ProgressViewModel
CanUseConfigMethod =
basePackage?.AvailableSharedFolderMethods.Contains(SharedFolderMethod.Configuration)
?? false;
+ CanUseSymlinkMethod =
+ basePackage?.AvailableSharedFolderMethods.Contains(SharedFolderMethod.Symlink)
+ ?? false;
+ UseSharedOutput = Package?.UseSharedOutputFolder ?? false;
+ CanUseSharedOutput = basePackage?.SharedOutputFolders != null;
}
}
@@ -243,7 +261,29 @@ public partial class PackageCardViewModel : ProgressViewModel
{
ModificationCompleteMessage = $"{packageName} Update Complete"
};
- var updatePackageStep = new UpdatePackageStep(settingsManager, Package, basePackage);
+
+ var versionOptions = new DownloadPackageVersionOptions { IsLatest = true };
+ if (Package.Version.IsReleaseMode)
+ {
+ versionOptions.VersionTag = await basePackage.GetLatestVersion();
+ }
+ else
+ {
+ var commits = await basePackage.GetAllCommits(Package.Version.InstalledBranch);
+ var latest = commits?.FirstOrDefault();
+ if (latest == null)
+ throw new Exception("Could not find latest commit");
+
+ versionOptions.BranchName = Package.Version.InstalledBranch;
+ versionOptions.CommitHash = latest.Sha;
+ }
+
+ var updatePackageStep = new UpdatePackageStep(
+ settingsManager,
+ Package,
+ versionOptions,
+ basePackage
+ );
var steps = new List { updatePackageStep };
EventManager.Instance.OnPackageInstallProgressAdded(runner);
@@ -335,6 +375,20 @@ public partial class PackageCardViewModel : ProgressViewModel
await ProcessRunner.OpenFolderBrowser(Package.FullPath);
}
+ [RelayCommand]
+ public async Task OpenPythonPackagesDialog()
+ {
+ if (Package is not { FullPath: not null })
+ return;
+
+ var vm = vmFactory.Get(vm =>
+ {
+ vm.VenvPath = new DirectoryPath(Package.FullPath, "venv");
+ });
+
+ await vm.GetDialog().ShowAsync();
+ }
+
private async Task HasUpdate()
{
if (Package == null || IsUnknownPackage || Design.IsDesignMode)
@@ -374,6 +428,33 @@ public partial class PackageCardViewModel : ProgressViewModel
public void ToggleSharedModelNone() => IsSharedModelDisabled = !IsSharedModelDisabled;
+ public void ToggleSharedOutput() => UseSharedOutput = !UseSharedOutput;
+
+ partial void OnUseSharedOutputChanged(bool value)
+ {
+ if (Package == null)
+ return;
+
+ if (value == Package.UseSharedOutputFolder)
+ return;
+
+ using var st = settingsManager.BeginTransaction();
+ Package.UseSharedOutputFolder = value;
+
+ var basePackage = packageFactory[Package.PackageName!];
+ if (basePackage == null)
+ return;
+
+ if (value)
+ {
+ basePackage.SetupOutputFolderLinks(Package.FullPath!);
+ }
+ else
+ {
+ basePackage.RemoveOutputFolderLinks(Package.FullPath!);
+ }
+ }
+
// fake radio button stuff
partial void OnIsSharedModelSymlinkChanged(bool oldValue, bool newValue)
{
diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
index 22c4dd10..5578a5f7 100644
--- a/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
@@ -7,9 +7,12 @@ using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
+using Avalonia.Controls.Primitives;
+using Avalonia.Threading;
using DynamicData;
using DynamicData.Binding;
using FluentAvalonia.UI.Controls;
+using Microsoft.Extensions.Logging;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
@@ -19,7 +22,6 @@ using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Avalonia.Views.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
-using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.PackageModification;
@@ -35,12 +37,13 @@ namespace StabilityMatrix.Avalonia.ViewModels;
///
[View(typeof(PackageManagerPage))]
+[Singleton]
public partial class PackageManagerViewModel : PageViewModelBase
{
private readonly ISettingsManager settingsManager;
- private readonly IPackageFactory packageFactory;
private readonly ServiceManager dialogFactory;
private readonly INotificationService notificationService;
+ private readonly ILogger logger;
public override string Title => "Packages";
public override IconSource IconSource =>
@@ -62,17 +65,19 @@ public partial class PackageManagerViewModel : PageViewModelBase
public IObservableCollection PackageCards { get; } =
new ObservableCollectionExtended();
+ private DispatcherTimer timer;
+
public PackageManagerViewModel(
ISettingsManager settingsManager,
- IPackageFactory packageFactory,
ServiceManager dialogFactory,
- INotificationService notificationService
+ INotificationService notificationService,
+ ILogger logger
)
{
this.settingsManager = settingsManager;
- this.packageFactory = packageFactory;
this.dialogFactory = dialogFactory;
this.notificationService = notificationService;
+ this.logger = logger;
EventManager.Instance.InstalledPackagesChanged += OnInstalledPackagesChanged;
@@ -93,6 +98,9 @@ public partial class PackageManagerViewModel : PageViewModelBase
)
.Bind(PackageCards)
.Subscribe();
+
+ timer = new DispatcherTimer { Interval = TimeSpan.FromMinutes(15), IsEnabled = true };
+ timer.Tick += async (_, _) => await CheckPackagesForUpdates();
}
public void SetPackages(IEnumerable packages)
@@ -117,22 +125,31 @@ public partial class PackageManagerViewModel : PageViewModelBase
var currentUnknown = await Task.Run(IndexUnknownPackages);
unknownInstalledPackages.Edit(s => s.Load(currentUnknown));
+
+ timer.Start();
+ }
+
+ public override void OnUnloaded()
+ {
+ timer.Stop();
+ base.OnUnloaded();
}
public async Task ShowInstallDialog(BasePackage? selectedPackage = null)
{
var viewModel = dialogFactory.Get();
- viewModel.AvailablePackages = packageFactory.GetAllAvailablePackages().ToImmutableArray();
viewModel.SelectedPackage = selectedPackage ?? viewModel.AvailablePackages[0];
var dialog = new BetterContentDialog
{
MaxDialogWidth = 900,
MinDialogWidth = 900,
+ FullSizeDesired = true,
DefaultButton = ContentDialogButton.Close,
IsPrimaryButtonEnabled = false,
IsSecondaryButtonEnabled = false,
IsFooterVisible = false,
+ ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled,
Content = new InstallerDialog { DataContext = viewModel }
};
@@ -157,6 +174,25 @@ public partial class PackageManagerViewModel : PageViewModelBase
}
}
+ private async Task CheckPackagesForUpdates()
+ {
+ foreach (var package in PackageCards)
+ {
+ try
+ {
+ await package.OnLoadedAsync();
+ }
+ catch (Exception e)
+ {
+ logger.LogError(
+ e,
+ "Failed to check for updates for {Package}",
+ package?.Package?.PackageName
+ );
+ }
+ }
+ }
+
private IEnumerable IndexUnknownPackages()
{
var packageDir = new DirectoryPath(settingsManager.LibraryDir).JoinDir("Packages");
diff --git a/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs
index dc410872..9cb0eb3d 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Progress/PackageInstallProgressItemViewModel.cs
@@ -17,7 +17,10 @@ public class PackageInstallProgressItemViewModel : ProgressItemViewModelBase
private readonly IPackageModificationRunner packageModificationRunner;
private BetterContentDialog? dialog;
- public PackageInstallProgressItemViewModel(IPackageModificationRunner packageModificationRunner)
+ public PackageInstallProgressItemViewModel(
+ IPackageModificationRunner packageModificationRunner,
+ bool hideCloseButton = false
+ )
{
this.packageModificationRunner = packageModificationRunner;
Id = packageModificationRunner.Id;
@@ -25,6 +28,7 @@ public class PackageInstallProgressItemViewModel : ProgressItemViewModelBase
Progress.Value = packageModificationRunner.CurrentProgress.Percentage;
Progress.Text = packageModificationRunner.ConsoleOutput.LastOrDefault();
Progress.IsIndeterminate = packageModificationRunner.CurrentProgress.IsIndeterminate;
+ Progress.HideCloseButton = hideCloseButton;
Progress.Console.StartUpdates();
diff --git a/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs
index 51b88020..445f0e0c 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs
@@ -22,6 +22,8 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels.Progress;
[View(typeof(ProgressManagerPage))]
+[ManagedService]
+[Singleton]
public partial class ProgressManagerViewModel : PageViewModelBase
{
private readonly INotificationService notificationService;
@@ -120,19 +122,22 @@ public partial class ProgressManagerViewModel : PageViewModelBase
}
}
- private async Task AddPackageInstall(IPackageModificationRunner packageModificationRunner)
+ private Task AddPackageInstall(IPackageModificationRunner packageModificationRunner)
{
if (ProgressItems.Any(vm => vm.Id == packageModificationRunner.Id))
{
- return;
+ return Task.CompletedTask;
}
- var vm = new PackageInstallProgressItemViewModel(packageModificationRunner);
+ var vm = new PackageInstallProgressItemViewModel(
+ packageModificationRunner,
+ packageModificationRunner.HideCloseButton
+ );
ProgressItems.Add(vm);
- if (packageModificationRunner.ShowDialogOnStart)
- {
- await vm.ShowProgressDialog();
- }
+
+ return packageModificationRunner.ShowDialogOnStart
+ ? vm.ShowProgressDialog()
+ : Task.CompletedTask;
}
private void ShowFailedNotification(string title, string message)
diff --git a/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs
index 10ee1194..b624f1e6 100644
--- a/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/RefreshBadgeViewModel.cs
@@ -16,10 +16,12 @@ namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(RefreshBadge))]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+[ManagedService]
+[Transient]
public partial class RefreshBadgeViewModel : ViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
-
+
public string WorkingToolTipText { get; set; } = "Loading...";
public string SuccessToolTipText { get; set; } = "Success";
public string InactiveToolTipText { get; set; } = "";
@@ -32,7 +34,7 @@ public partial class RefreshBadgeViewModel : ViewModelBase
public IBrush SuccessColorBrush { get; set; } = ThemeColors.ThemeGreen;
public IBrush InactiveColorBrush { get; set; } = ThemeColors.ThemeYellow;
public IBrush FailColorBrush { get; set; } = ThemeColors.ThemeYellow;
-
+
public Func>? RefreshFunc { get; set; }
[ObservableProperty]
@@ -41,7 +43,7 @@ public partial class RefreshBadgeViewModel : ViewModelBase
[NotifyPropertyChangedFor(nameof(CurrentToolTip))]
[NotifyPropertyChangedFor(nameof(Icon))]
private ProgressState state;
-
+
public bool IsWorking => State == ProgressState.Working;
/*public ControlAppearance Appearance => State switch
@@ -51,36 +53,40 @@ public partial class RefreshBadgeViewModel : ViewModelBase
ProgressState.Failed => ControlAppearance.Danger,
_ => ControlAppearance.Secondary
};*/
-
- public IBrush ColorBrush => State switch
- {
- ProgressState.Success => SuccessColorBrush,
- ProgressState.Inactive => InactiveColorBrush,
- ProgressState.Failed => FailColorBrush,
- _ => Brushes.Gray
- };
- public string CurrentToolTip => State switch
- {
- ProgressState.Working => WorkingToolTipText,
- ProgressState.Success => SuccessToolTipText,
- ProgressState.Inactive => InactiveToolTipText,
- ProgressState.Failed => FailToolTipText,
- _ => ""
- };
-
- public Symbol Icon => State switch
- {
- ProgressState.Success => SuccessIcon,
- ProgressState.Failed => FailIcon,
- _ => InactiveIcon
- };
+ public IBrush ColorBrush =>
+ State switch
+ {
+ ProgressState.Success => SuccessColorBrush,
+ ProgressState.Inactive => InactiveColorBrush,
+ ProgressState.Failed => FailColorBrush,
+ _ => Brushes.Gray
+ };
+
+ public string CurrentToolTip =>
+ State switch
+ {
+ ProgressState.Working => WorkingToolTipText,
+ ProgressState.Success => SuccessToolTipText,
+ ProgressState.Inactive => InactiveToolTipText,
+ ProgressState.Failed => FailToolTipText,
+ _ => ""
+ };
+
+ public Symbol Icon =>
+ State switch
+ {
+ ProgressState.Success => SuccessIcon,
+ ProgressState.Failed => FailIcon,
+ _ => InactiveIcon
+ };
[RelayCommand]
private async Task Refresh()
{
Logger.Info("Running refresh command...");
- if (RefreshFunc == null) return;
+ if (RefreshFunc == null)
+ return;
State = ProgressState.Working;
try
diff --git a/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs
index cec925fb..80d5c7ba 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs
@@ -1,12 +1,9 @@
-using CommunityToolkit.Mvvm.ComponentModel;
-using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views.Settings;
using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.ViewModels.Settings;
[View(typeof(InferenceSettingsPage))]
-public partial class InferenceSettingsViewModel : ViewModelBase
-{
-
-}
+[Singleton]
+public partial class InferenceSettingsViewModel : ViewModelBase { }
diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
index eee148fa..74dc145b 100644
--- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
@@ -3,15 +3,16 @@ using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.ComponentModel;
+using System.ComponentModel.DataAnnotations;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
+using System.Reactive.Linq;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
-using AsyncAwaitBestPractices;
using Avalonia;
using Avalonia.Controls.Notifications;
using Avalonia.Controls.Primitives;
@@ -21,6 +22,7 @@ using Avalonia.Styling;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
+using DynamicData.Binding;
using FluentAvalonia.UI.Controls;
using NLog;
using SkiaSharp;
@@ -29,6 +31,7 @@ using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
+using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Models.TagCompletion;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
@@ -49,6 +52,7 @@ using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(SettingsPage))]
+[Singleton]
public partial class SettingsViewModel : PageViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -107,6 +111,25 @@ public partial class SettingsViewModel : PageViewModelBase
[ObservableProperty]
private bool isCompletionRemoveUnderscoresEnabled = true;
+ [ObservableProperty]
+ [CustomValidation(typeof(SettingsViewModel), nameof(ValidateOutputImageFileNameFormat))]
+ private string? outputImageFileNameFormat;
+
+ [ObservableProperty]
+ private string? outputImageFileNameFormatSample;
+
+ public IEnumerable OutputImageFileNameFormatVars =>
+ FileNameFormatProvider
+ .GetSample()
+ .Substitutions.Select(
+ kv =>
+ new FileNameFormatVar
+ {
+ Variable = $"{{{kv.Key}}}",
+ Example = kv.Value.Invoke()
+ }
+ );
+
[ObservableProperty]
private bool isImageViewerPixelGridEnabled = true;
@@ -201,6 +224,39 @@ public partial class SettingsViewModel : PageViewModelBase
true
);
+ this.WhenPropertyChanged(vm => vm.OutputImageFileNameFormat)
+ .Throttle(TimeSpan.FromMilliseconds(50))
+ .Subscribe(formatProperty =>
+ {
+ var provider = FileNameFormatProvider.GetSample();
+ var template = formatProperty.Value ?? string.Empty;
+
+ if (
+ !string.IsNullOrEmpty(template)
+ && provider.Validate(template) == ValidationResult.Success
+ )
+ {
+ var format = FileNameFormat.Parse(template, provider);
+ OutputImageFileNameFormatSample = format.GetFileName() + ".png";
+ }
+ else
+ {
+ // Use default format if empty
+ var defaultFormat = FileNameFormat.Parse(
+ FileNameFormat.DefaultTemplate,
+ provider
+ );
+ OutputImageFileNameFormatSample = defaultFormat.GetFileName() + ".png";
+ }
+ });
+
+ settingsManager.RelayPropertyFor(
+ this,
+ vm => vm.OutputImageFileNameFormat,
+ settings => settings.InferenceOutputImageFileNameFormat,
+ true
+ );
+
settingsManager.RelayPropertyFor(
this,
vm => vm.IsImageViewerPixelGridEnabled,
@@ -225,6 +281,14 @@ public partial class SettingsViewModel : PageViewModelBase
UpdateAvailableTagCompletionCsvs();
}
+ public static ValidationResult ValidateOutputImageFileNameFormat(
+ string? format,
+ ValidationContext context
+ )
+ {
+ return FileNameFormatProvider.GetSample().Validate(format ?? string.Empty);
+ }
+
partial void OnSelectedThemeChanged(string? value)
{
// In case design / tests
diff --git a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml
index 7ad45f0c..e14edf4a 100644
--- a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml
+++ b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml
@@ -21,7 +21,7 @@
@@ -66,6 +66,7 @@
-
+
-
+
-
+
diff --git a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml.cs b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml.cs
index f61dae4c..34e0887a 100644
--- a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia.Markup.Xaml;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views;
+[Singleton]
public partial class CheckpointBrowserPage : UserControlBase
{
public CheckpointBrowserPage()
diff --git a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
index 55f759ec..1c8d8d0b 100644
--- a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
+++ b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
@@ -30,7 +30,7 @@
Height="18"
Margin="4,0,0,0"
Padding="3"
- Width="40">
+ Width="48">
+ Margin="4">
-
+
@@ -65,9 +66,9 @@
-
+
@@ -121,15 +122,17 @@
-
+
-
+
@@ -213,8 +215,9 @@
@@ -268,7 +271,8 @@
ItemTemplate="{StaticResource CheckpointFileDataTemplate}"
ItemsSource="{Binding DisplayedCheckpointFiles}">
-
+
@@ -280,7 +284,7 @@
IsVisible="{Binding !CheckpointFiles.Count}"/>
@@ -445,6 +449,8 @@
diff --git a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml.cs b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml.cs
index 3cccca86..ee5c2946 100644
--- a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml.cs
@@ -1,17 +1,29 @@
-using Avalonia.Controls;
+using System;
+using System.Linq;
+using Avalonia.Controls;
using Avalonia.Input;
+using Avalonia.Interactivity;
using Avalonia.Markup.Xaml;
+using Avalonia.VisualTree;
+using DynamicData.Binding;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Avalonia.ViewModels;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Helper;
using CheckpointFolder = StabilityMatrix.Avalonia.ViewModels.CheckpointManager.CheckpointFolder;
namespace StabilityMatrix.Avalonia.Views;
+[Singleton]
public partial class CheckpointsPage : UserControlBase
{
+ private ItemsControl? repeater;
+ private IDisposable? subscription;
+
public CheckpointsPage()
{
InitializeComponent();
-
+
AddHandler(DragDrop.DragEnterEvent, OnDragEnter);
AddHandler(DragDrop.DragLeaveEvent, OnDragExit);
AddHandler(DragDrop.DropEvent, OnDrop);
@@ -21,6 +33,34 @@ public partial class CheckpointsPage : UserControlBase
{
AvaloniaXamlLoader.Load(this);
}
+
+ protected override void OnDataContextChanged(EventArgs e)
+ {
+ base.OnDataContextChanged(e);
+
+ subscription?.Dispose();
+ subscription = null;
+
+ if (DataContext is CheckpointsPageViewModel vm)
+ {
+ subscription = vm.WhenPropertyChanged(m => m.ShowConnectedModelImages)
+ .Subscribe(_ => InvalidateRepeater());
+ }
+ }
+
+ private void InvalidateRepeater()
+ {
+ repeater ??= this.FindControl("FilesRepeater");
+ repeater?.InvalidateArrange();
+ repeater?.InvalidateMeasure();
+
+ foreach (var child in this.GetVisualDescendants().OfType())
+ {
+ child?.InvalidateArrange();
+ child?.InvalidateMeasure();
+ }
+ }
+
private static async void OnDrop(object? sender, DragEventArgs e)
{
var sourceDataContext = (e.Source as Control)?.DataContext;
@@ -29,7 +69,7 @@ public partial class CheckpointsPage : UserControlBase
await folder.OnDrop(e);
}
}
-
+
private static void OnDragExit(object? sender, DragEventArgs e)
{
var sourceDataContext = (e.Source as Control)?.DataContext;
@@ -38,7 +78,7 @@ public partial class CheckpointsPage : UserControlBase
folder.IsCurrentDragTarget = false;
}
}
-
+
private static void OnDragEnter(object? sender, DragEventArgs e)
{
// Only allow Copy or Link as Drop Operations.
@@ -47,9 +87,9 @@ public partial class CheckpointsPage : UserControlBase
// Only allow if the dragged data contains text or filenames.
if (!e.Data.Contains(DataFormats.Text) && !e.Data.Contains(DataFormats.Files))
{
- e.DragEffects = DragDropEffects.None;
+ e.DragEffects = DragDropEffects.None;
}
-
+
// Forward to view model
var sourceDataContext = (e.Source as Control)?.DataContext;
if (sourceDataContext is CheckpointFolder folder)
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/DownloadResourceDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/DownloadResourceDialog.axaml.cs
index 22d34516..3928f022 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/DownloadResourceDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/DownloadResourceDialog.axaml.cs
@@ -1,12 +1,13 @@
-using Avalonia;
-using Avalonia.Controls;
+using Avalonia.Controls;
using Avalonia.Input;
using Avalonia.Markup.Xaml;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class DownloadResourceDialog : UserControl
{
public DownloadResourceDialog()
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/EnvVarsDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/EnvVarsDialog.axaml.cs
index 63068706..3c54d2c2 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/EnvVarsDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/EnvVarsDialog.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia.Markup.Xaml;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class EnvVarsDialog : UserControlBase
{
public EnvVarsDialog()
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/ExceptionDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/ExceptionDialog.axaml.cs
index bbd91996..b2bf44ad 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/ExceptionDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/ExceptionDialog.axaml.cs
@@ -3,15 +3,17 @@ using Avalonia.Interactivity;
using Avalonia.Markup.Xaml;
using FluentAvalonia.UI.Windowing;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class ExceptionDialog : AppWindowBase
{
public ExceptionDialog()
{
InitializeComponent();
-
+
TitleBar.ExtendsContentIntoTitleBar = true;
TitleBar.TitleBarHitTestType = TitleBarHitTestType.Complex;
}
@@ -20,7 +22,7 @@ public partial class ExceptionDialog : AppWindowBase
{
AvaloniaXamlLoader.Load(this);
}
-
+
[SuppressMessage("ReSharper", "UnusedParameter.Local")]
private void ExitButton_OnClick(object? sender, RoutedEventArgs e)
{
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml
index 1ab9fde9..39152677 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml
@@ -106,6 +106,20 @@
Grid.Row="1"
Text="{Binding NegativePrompt}" />
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml.cs
index 7af815c6..fc921fee 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml.cs
@@ -1,10 +1,11 @@
using Avalonia;
using Avalonia.Input;
-using Avalonia.Interactivity;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class ImageViewerDialog : UserControlBase
{
public static readonly StyledProperty IsFooterEnabledProperty = AvaloniaProperty.Register<
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/InferenceConnectionHelpDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/InferenceConnectionHelpDialog.axaml.cs
index fc0e45d5..b2568572 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/InferenceConnectionHelpDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/InferenceConnectionHelpDialog.axaml.cs
@@ -1,9 +1,10 @@
-using Avalonia;
-using Avalonia.Controls;
+using Avalonia.Controls;
using Avalonia.Markup.Xaml;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class InferenceConnectionHelpDialog : UserControl
{
public InferenceConnectionHelpDialog()
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml
index d3792f60..513144cd 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml
@@ -16,34 +16,140 @@
x:Class="StabilityMatrix.Avalonia.Views.Dialogs.InstallerDialog">
+ RowDefinitions="Auto, Auto, Auto, Auto, *, Auto">
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -60,7 +166,7 @@
-
+
-
+
@@ -185,7 +291,7 @@
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/LaunchOptionsDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/LaunchOptionsDialog.axaml.cs
index 84d4d3a6..5dd11fc4 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/LaunchOptionsDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/LaunchOptionsDialog.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia.Controls;
using Avalonia.Markup.Xaml;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class LaunchOptionsDialog : UserControl
{
public LaunchOptionsDialog()
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml
index 1f1e6e62..8df43ca8 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml
@@ -8,13 +8,14 @@
xmlns:packages="clr-namespace:StabilityMatrix.Core.Models.Packages;assembly=StabilityMatrix.Core"
xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
+ xmlns:models="clr-namespace:StabilityMatrix.Core.Models;assembly=StabilityMatrix.Core"
mc:Ignorable="d" d:DesignWidth="700" d:DesignHeight="700"
x:DataType="dialogs:OneClickInstallViewModel"
d:DataContext="{x:Static designData:DesignData.OneClickInstallViewModel}"
x:Class="StabilityMatrix.Avalonia.Views.Dialogs.OneClickInstallDialog">
@@ -37,8 +38,10 @@
Title="Use ComfyUI with Inference"
Subtitle="A new built-in native Stable Diffusion experience, powered by ComfyUI"
ActionButtonContent="{x:Static lang:Resources.Action_Install}"
+ ActionButtonCommand="{Binding InstallComfyForInferenceCommand}"
CloseButtonContent="{x:Static lang:Resources.Action_Close}"
PreferredPlacement="RightTop"
+ Margin="8,0,0,0"
PlacementMargin="0,0,0,0"
TailVisibility="Auto">
@@ -68,16 +71,123 @@
FontSize="24"
Margin="16, 16, 0, 4"/>
+ TextWrapping="Wrap"
+ Margin="16, 0, 0, 4"/>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -114,7 +224,7 @@
FontSize="32"
HorizontalAlignment="Center"
Classes="success"
- Margin="16"
+ Margin="8"
Padding="16, 8, 16, 8" />
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml.cs
index 3e6ff511..d615eddc 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/OneClickInstallDialog.axaml.cs
@@ -3,7 +3,6 @@ using System.Linq;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using Avalonia.Interactivity;
-using Avalonia.Markup.Xaml;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Core.Models.Packages;
@@ -30,7 +29,12 @@ public partial class OneClickInstallDialog : UserControl
var teachingTip =
this.FindControl("InferenceTeachingTip")
?? throw new InvalidOperationException("TeachingTip not found");
- ;
+
+ teachingTip.ActionButtonClick += (_, _) =>
+ {
+ teachingTip.IsOpen = false;
+ };
+
// Find ComfyUI listbox item
var listBox = this.FindControl("PackagesListBox");
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs
index a5f8cf50..7123dfa9 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia.Markup.Xaml;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class PackageImportDialog : UserControlBase
{
public PackageImportDialog()
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PackageModificationDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/PackageModificationDialog.axaml
index 7b9d65fb..cb4de0be 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/PackageModificationDialog.axaml
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/PackageModificationDialog.axaml
@@ -12,7 +12,8 @@
x:Class="StabilityMatrix.Avalonia.Views.Dialogs.PackageModificationDialog">
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml
new file mode 100644
index 00000000..bb7dddda
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml
@@ -0,0 +1,248 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml.cs
new file mode 100644
index 00000000..197a94ad
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/PythonPackagesDialog.axaml.cs
@@ -0,0 +1,13 @@
+using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
+
+namespace StabilityMatrix.Avalonia.Views.Dialogs;
+
+[Transient]
+public partial class PythonPackagesDialog : UserControlBase
+{
+ public PythonPackagesDialog()
+ {
+ InitializeComponent();
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs
index 3e874a83..f9eb3bf1 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia.Markup.Xaml;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class SelectDataDirectoryDialog : UserControlBase
{
public SelectDataDirectoryDialog()
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs
index 33d2b218..94ac9c31 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/UpdateDialog.axaml.cs
@@ -1,8 +1,10 @@
using Avalonia.Controls;
using Avalonia.Markup.Xaml;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Dialogs;
+[Transient]
public partial class UpdateDialog : UserControl
{
public UpdateDialog()
@@ -14,4 +16,4 @@ public partial class UpdateDialog : UserControl
{
AvaloniaXamlLoader.Load(this);
}
-}
\ No newline at end of file
+}
diff --git a/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs b/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs
index f02ea254..e3a9cf7f 100644
--- a/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/FirstLaunchSetupWindow.axaml.cs
@@ -1,13 +1,13 @@
-using System;
-using System.Diagnostics.CodeAnalysis;
-using Avalonia;
+using System.Diagnostics.CodeAnalysis;
using Avalonia.Interactivity;
using Avalonia.Markup.Xaml;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views;
+[Singleton]
public partial class FirstLaunchSetupWindow : AppWindowBase
{
public ContentDialogResult Result { get; private set; }
diff --git a/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs b/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs
index eaa6a883..0e13e9d4 100644
--- a/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Inference/InferenceImageUpscaleView.axaml.cs
@@ -1,7 +1,9 @@
using StabilityMatrix.Avalonia.Controls.Dock;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Inference;
+[Transient]
public partial class InferenceImageUpscaleView : DockUserControlBase
{
public InferenceImageUpscaleView()
diff --git a/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs b/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs
index 11784fe4..d9116275 100644
--- a/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/Inference/InferenceTextToImageView.axaml.cs
@@ -1,7 +1,9 @@
using StabilityMatrix.Avalonia.Controls.Dock;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views.Inference;
+[Transient]
public partial class InferenceTextToImageView : DockUserControlBase
{
public InferenceTextToImageView()
diff --git a/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs b/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs
index 7dd4e5ed..05e4eb76 100644
--- a/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs
+++ b/StabilityMatrix.Avalonia/Views/InferencePage.axaml.cs
@@ -7,9 +7,11 @@ using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.ViewModels;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Views;
+[Singleton]
public partial class InferencePage : UserControlBase
{
private Button? _addButton;
diff --git a/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml b/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml
index e77a3cfa..cd81a8a8 100644
--- a/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml
+++ b/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml
@@ -35,7 +35,7 @@
Grid.Row="0"
Grid.Column="0">
public async Task MoveToAsync(FilePath destinationFile)
{
- await Task.Run(() => Info.MoveTo(destinationFile.FullPath)).ConfigureAwait(false);
+ await Task.Run(() =>
+ {
+ var path = destinationFile.FullPath;
+ if (destinationFile.Exists)
+ {
+ var num = Random.Shared.NextInt64(0, 10000);
+ path = path.Replace(
+ destinationFile.NameWithoutExtension,
+ $"{destinationFile.NameWithoutExtension}_{num}"
+ );
+ }
+
+ Info.MoveTo(path);
+ })
+ .ConfigureAwait(false);
// Return the new path
return destinationFile;
}
diff --git a/StabilityMatrix.Core/Models/GenerationParameters.cs b/StabilityMatrix.Core/Models/GenerationParameters.cs
index fab9f793..c0b5aa9b 100644
--- a/StabilityMatrix.Core/Models/GenerationParameters.cs
+++ b/StabilityMatrix.Core/Models/GenerationParameters.cs
@@ -1,4 +1,6 @@
-using System.Diagnostics.CodeAnalysis;
+using System.ComponentModel.DataAnnotations;
+using System.Diagnostics.CodeAnalysis;
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
using StabilityMatrix.Core.Models.Api.Comfy;
@@ -30,12 +32,26 @@ public partial record GenerationParameters
return false;
}
+ try
+ {
+ generationParameters = Parse(text);
+ }
+ catch (Exception)
+ {
+ generationParameters = null;
+ return false;
+ }
+
+ return true;
+ }
+
+ public static GenerationParameters Parse(string text)
+ {
var lines = text.Split('\n');
if (lines.LastOrDefault() is not { } lastLine)
{
- generationParameters = null;
- return false;
+ throw new ValidationException("Fields line not found");
}
if (lastLine.StartsWith("Steps:") != true)
@@ -45,47 +61,84 @@ public partial record GenerationParameters
if (lastLine.StartsWith("Steps:") != true)
{
- generationParameters = null;
- return false;
+ throw new ValidationException("Unable to locate starting marker of last line");
}
}
// Join lines before last line, split at 'Negative prompt: '
- var joinedLines = string.Join("\n", lines[..^1]);
+ var joinedLines = string.Join("\n", lines[..^1]).Trim();
- var splitFirstPart = joinedLines.Split("Negative prompt: ");
- if (splitFirstPart.Length != 2)
- {
- generationParameters = null;
- return false;
- }
+ var splitFirstPart = joinedLines.Split("Negative prompt: ", 2);
- var positivePrompt = splitFirstPart[0];
- var negativePrompt = splitFirstPart[1];
+ var positivePrompt = splitFirstPart.ElementAtOrDefault(0)?.Trim();
+ var negativePrompt = splitFirstPart.ElementAtOrDefault(1)?.Trim();
// Parse last line
- var match = ParseLastLineRegex().Match(lastLine);
- if (!match.Success)
- {
- generationParameters = null;
- return false;
- }
+ var lineFields = ParseLine(lastLine);
- generationParameters = new GenerationParameters
+ var generationParameters = new GenerationParameters
{
PositivePrompt = positivePrompt,
NegativePrompt = negativePrompt,
- Steps = int.Parse(match.Groups["Steps"].Value),
- Sampler = match.Groups["Sampler"].Value,
- CfgScale = double.Parse(match.Groups["CfgScale"].Value),
- Seed = ulong.Parse(match.Groups["Seed"].Value),
- Height = int.Parse(match.Groups["Height"].Value),
- Width = int.Parse(match.Groups["Width"].Value),
- ModelHash = match.Groups["ModelHash"].Value,
- ModelName = match.Groups["ModelName"].Value,
+ Steps = int.Parse(lineFields.GetValueOrDefault("Steps", "0")),
+ Sampler = lineFields.GetValueOrDefault("Sampler"),
+ CfgScale = double.Parse(lineFields.GetValueOrDefault("CFG scale", "0")),
+ Seed = ulong.Parse(lineFields.GetValueOrDefault("Seed", "0")),
+ ModelHash = lineFields.GetValueOrDefault("Model hash"),
+ ModelName = lineFields.GetValueOrDefault("Model"),
};
- return true;
+ if (lineFields.GetValueOrDefault("Size") is { } size)
+ {
+ var split = size.Split('x', 2);
+ if (split.Length == 2)
+ {
+ generationParameters = generationParameters with
+ {
+ Width = int.Parse(split[0]),
+ Height = int.Parse(split[1])
+ };
+ }
+ }
+
+ return generationParameters;
+ }
+
+ ///
+ /// Parse A1111 metadata fields in a single line where
+ /// fields are separated by commas and key-value pairs are separated by colons.
+ /// i.e. "key1: value1, key2: value2"
+ ///
+ internal static Dictionary ParseLine(string fields)
+ {
+ var dict = new Dictionary();
+
+ // Values main contain commas or colons
+ foreach (var match in ParametersFieldsRegex().Matches(fields).Cast())
+ {
+ if (!match.Success)
+ continue;
+
+ var key = match.Groups[1].Value.Trim();
+ var value = UnquoteValue(match.Groups[2].Value.Trim());
+
+ dict.Add(key, value);
+ }
+
+ return dict;
+ }
+
+ ///
+ /// Unquotes a quoted value field if required
+ ///
+ private static string UnquoteValue(string quotedField)
+ {
+ if (!(quotedField.StartsWith('"') && quotedField.EndsWith('"')))
+ {
+ return quotedField;
+ }
+
+ return JsonNode.Parse(quotedField)?.GetValue() ?? "";
}
///
@@ -126,9 +179,26 @@ public partial record GenerationParameters
return (sampler, scheduler);
}
- // Example: Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1
- [GeneratedRegex(
- """^Steps: (?\d+), Sampler: (?.+?), CFG scale: (?\d+(\.\d+)?), Seed: (?\d+), Size: (?\d+)x(?\d+), Model hash: (?.+?), Model: (?.+)$"""
- )]
- private static partial Regex ParseLastLineRegex();
+ ///
+ /// Return a sample parameters for UI preview
+ ///
+ public static GenerationParameters GetSample()
+ {
+ return new GenerationParameters
+ {
+ PositivePrompt = "(cat:1.2), by artist, detailed, [shaded]",
+ NegativePrompt = "blurry, jpg artifacts",
+ Steps = 30,
+ CfgScale = 7,
+ Width = 640,
+ Height = 896,
+ Seed = 124825529,
+ ModelName = "ExampleMix7",
+ ModelHash = "b899d188a1ac7356bfb9399b2277d5b21712aa360f8f9514fba6fcce021baff7",
+ Sampler = "DPM++ 2M Karras"
+ };
+ }
+
+ [GeneratedRegex("""\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)""")]
+ private static partial Regex ParametersFieldsRegex();
}
diff --git a/StabilityMatrix.Core/Models/HybridModelFile.cs b/StabilityMatrix.Core/Models/HybridModelFile.cs
index 2f178fd0..60fd8d4f 100644
--- a/StabilityMatrix.Core/Models/HybridModelFile.cs
+++ b/StabilityMatrix.Core/Models/HybridModelFile.cs
@@ -30,7 +30,10 @@ public record HybridModelFile
public bool IsRemote => RemoteName != null;
[JsonIgnore]
- public string FileName => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder;
+ public string RelativePath => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder;
+
+ [JsonIgnore]
+ public string FileName => Path.GetFileName(RelativePath);
[JsonIgnore]
public string ShortDisplayName
@@ -47,7 +50,7 @@ public record HybridModelFile
return "Default";
}
- return Path.GetFileNameWithoutExtension(FileName);
+ return Path.GetFileNameWithoutExtension(RelativePath);
}
}
@@ -63,7 +66,7 @@ public record HybridModelFile
public string GetId()
{
- return $"{FileName};{IsNone};{IsDefault}";
+ return $"{RelativePath};{IsNone};{IsDefault}";
}
private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer
@@ -79,14 +82,14 @@ public record HybridModelFile
if (x.GetType() != y.GetType())
return false;
- return Equals(x.FileName, y.FileName)
+ return Equals(x.RelativePath, y.RelativePath)
&& x.IsNone == y.IsNone
&& x.IsDefault == y.IsDefault;
}
public int GetHashCode(HybridModelFile obj)
{
- return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.FileName);
+ return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.RelativePath);
}
}
diff --git a/StabilityMatrix.Core/Models/InstalledPackage.cs b/StabilityMatrix.Core/Models/InstalledPackage.cs
index 4695bb12..557d07ab 100644
--- a/StabilityMatrix.Core/Models/InstalledPackage.cs
+++ b/StabilityMatrix.Core/Models/InstalledPackage.cs
@@ -50,6 +50,7 @@ public class InstalledPackage : IJsonOnDeserialized
public bool UpdateAvailable { get; set; }
public TorchVersion? PreferredTorchVersion { get; set; }
public SharedFolderMethod? PreferredSharedFolderMethod { get; set; }
+ public bool UseSharedOutputFolder { get; set; }
///
/// Get the launch args host option value.
diff --git a/StabilityMatrix.Core/Models/PackageDifficulty.cs b/StabilityMatrix.Core/Models/PackageDifficulty.cs
new file mode 100644
index 00000000..51eacc46
--- /dev/null
+++ b/StabilityMatrix.Core/Models/PackageDifficulty.cs
@@ -0,0 +1,12 @@
+namespace StabilityMatrix.Core.Models;
+
+public enum PackageDifficulty
+{
+ Recommended = 0,
+ Simple = 1,
+ Advanced = 2,
+ Expert = 3,
+ Nightmare = 4,
+ UltraNightmare = 5,
+ Impossible = 999
+}
diff --git a/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs b/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs
index 6c6f1e9c..6037b106 100644
--- a/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs
+++ b/StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs
@@ -12,6 +12,7 @@ public interface IPackageModificationRunner
List ConsoleOutput { get; }
Guid Id { get; }
bool ShowDialogOnStart { get; init; }
+ bool HideCloseButton { get; init; }
string? ModificationCompleteMessage { get; init; }
bool Failed { get; set; }
}
diff --git a/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs b/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs
index b307e161..372fdb5f 100644
--- a/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs
+++ b/StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs
@@ -8,12 +8,19 @@ public class InstallPackageStep : IPackageStep
{
private readonly BasePackage package;
private readonly TorchVersion torchVersion;
+ private readonly DownloadPackageVersionOptions versionOptions;
private readonly string installPath;
- public InstallPackageStep(BasePackage package, TorchVersion torchVersion, string installPath)
+ public InstallPackageStep(
+ BasePackage package,
+ TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
+ string installPath
+ )
{
this.package = package;
this.torchVersion = torchVersion;
+ this.versionOptions = versionOptions;
this.installPath = installPath;
}
@@ -25,9 +32,9 @@ public class InstallPackageStep : IPackageStep
}
await package
- .InstallPackage(installPath, torchVersion, progress, OnConsoleOutput)
+ .InstallPackage(installPath, torchVersion, versionOptions, progress, OnConsoleOutput)
.ConfigureAwait(false);
}
- public string ProgressTitle => "Installing package...";
+ public string ProgressTitle => $"Installing {package.DisplayName}...";
}
diff --git a/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs b/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs
index c546671e..0adbe468 100644
--- a/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs
+++ b/StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs
@@ -54,6 +54,7 @@ public class PackageModificationRunner : IPackageModificationRunner
IsRunning = false;
}
+ public bool HideCloseButton { get; init; }
public string? ModificationCompleteMessage { get; init; }
public bool ShowDialogOnStart { get; init; }
diff --git a/StabilityMatrix.Core/Models/PackageModification/PipStep.cs b/StabilityMatrix.Core/Models/PackageModification/PipStep.cs
new file mode 100644
index 00000000..6d5d0cfa
--- /dev/null
+++ b/StabilityMatrix.Core/Models/PackageModification/PipStep.cs
@@ -0,0 +1,45 @@
+using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+
+namespace StabilityMatrix.Core.Models.PackageModification;
+
+public class PipStep : IPackageStep
+{
+ public required ProcessArgs Args { get; init; }
+ public required DirectoryPath VenvDirectory { get; init; }
+
+ public DirectoryPath? WorkingDirectory { get; init; }
+
+ public IReadOnlyDictionary? EnvironmentVariables { get; init; }
+
+ ///
+ public string ProgressTitle =>
+ Args switch
+ {
+ _ when Args.Contains("install") => "Installing Pip Packages",
+ _ when Args.Contains("uninstall") => "Uninstalling Pip Packages",
+ _ when Args.Contains("-U") || Args.Contains("--upgrade") => "Updating Pip Packages",
+ _ => "Running Pip"
+ };
+
+ ///
+ public async Task ExecuteAsync(IProgress? progress = null)
+ {
+ await using var venvRunner = new PyVenvRunner(VenvDirectory)
+ {
+ WorkingDirectory = WorkingDirectory,
+ EnvironmentVariables = EnvironmentVariables
+ };
+
+ var args = new List { "-m", "pip" };
+ args.AddRange(Args.ToArray());
+
+ venvRunner.RunDetached(args.ToArray(), progress.AsProcessOutputHandler());
+
+ await ProcessRunner.WaitForExitConditionAsync(venvRunner.Process).ConfigureAwait(false);
+ }
+}
diff --git a/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs b/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs
index e7fbfa97..be996cfe 100644
--- a/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs
+++ b/StabilityMatrix.Core/Models/PackageModification/UpdatePackageStep.cs
@@ -9,16 +9,19 @@ public class UpdatePackageStep : IPackageStep
{
private readonly ISettingsManager settingsManager;
private readonly InstalledPackage installedPackage;
+ private readonly DownloadPackageVersionOptions versionOptions;
private readonly BasePackage basePackage;
public UpdatePackageStep(
ISettingsManager settingsManager,
InstalledPackage installedPackage,
+ DownloadPackageVersionOptions versionOptions,
BasePackage basePackage
)
{
this.settingsManager = settingsManager;
this.installedPackage = installedPackage;
+ this.versionOptions = versionOptions;
this.basePackage = basePackage;
}
@@ -33,7 +36,13 @@ public class UpdatePackageStep : IPackageStep
}
var updateResult = await basePackage
- .Update(installedPackage, torchVersion, progress, onConsoleOutput: OnConsoleOutput)
+ .Update(
+ installedPackage,
+ torchVersion,
+ versionOptions,
+ progress,
+ onConsoleOutput: OnConsoleOutput
+ )
.ConfigureAwait(false);
settingsManager.UpdatePackageVersionNumber(installedPackage.Id, updateResult);
diff --git a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs
index 7f729d91..50defb6c 100644
--- a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs
+++ b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs
@@ -2,6 +2,7 @@
using System.Text.Json.Nodes;
using System.Text.RegularExpressions;
using NLog;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.FileInterfaces;
@@ -12,6 +13,7 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class A3WebUI : BaseGitPackage
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -29,6 +31,8 @@ public class A3WebUI : BaseGitPackage
new("https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png");
public string RelativeArgsDefinitionScriptPath => "modules.cmd_args";
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Recommended;
+
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
public A3WebUI(
@@ -61,6 +65,17 @@ public class A3WebUI : BaseGitPackage
[SharedFolderType.AfterDetailer] = new[] { "models/adetailer" }
};
+ public override Dictionary>? SharedOutputFolders =>
+ new()
+ {
+ [SharedOutputType.Extras] = new[] { "outputs/extras-images" },
+ [SharedOutputType.Saved] = new[] { "log/images" },
+ [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" },
+ [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" },
+ [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" },
+ [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" }
+ };
+
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List LaunchOptions =>
new()
@@ -157,7 +172,7 @@ public class A3WebUI : BaseGitPackage
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
public override IEnumerable AvailableTorchVersions =>
- new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm };
+ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm };
public override async Task GetLatestVersion()
{
@@ -165,15 +180,16 @@ public class A3WebUI : BaseGitPackage
return release.TagName!;
}
+ public override string OutputFolderName => "outputs";
+
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
{
- await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
-
progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
// Setup venv
await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
@@ -191,15 +207,14 @@ public class A3WebUI : BaseGitPackage
case TorchVersion.Rocm:
await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
break;
- case TorchVersion.DirectMl:
- await InstallDirectMlTorch(venvRunner, progress, onConsoleOutput)
- .ConfigureAwait(false);
- break;
default:
throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null);
}
- await venvRunner.PipInstall("httpx==0.24.1", onConsoleOutput);
+ if (versionOptions.VersionTag?.Contains("1.6.0") ?? false)
+ {
+ await venvRunner.PipInstall("httpx==0.24.1", onConsoleOutput);
+ }
// Install requirements file
progress?.Report(
@@ -212,10 +227,6 @@ public class A3WebUI : BaseGitPackage
.PipInstallFromRequirements(requirements, onConsoleOutput, excludes: "torch")
.ConfigureAwait(false);
- progress?.Report(
- new ProgressReport(1f, "Installing Package Requirements", isIndeterminate: false)
- );
-
progress?.Report(new ProgressReport(-1f, "Updating configuration", isIndeterminate: true));
// Create and add {"show_progress_type": "TAESD"} to config.json
@@ -273,7 +284,13 @@ public class A3WebUI : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
- .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput)
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.0.1")
+ .WithTorchVision()
+ .WithTorchExtraIndex("rocm5.1.1"),
+ onConsoleOutput
+ )
.ConfigureAwait(false);
}
}
diff --git a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
index dab562cc..b5984cc8 100644
--- a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
+++ b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
@@ -170,31 +170,45 @@ public abstract class BaseGitPackage : BasePackage
IProgress? progress = null
)
{
- var downloadUrl = GetDownloadUrl(versionOptions);
-
- if (!Directory.Exists(DownloadLocation.Replace($"{Name}.zip", "")))
+ if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag))
+ {
+ await PrerequisiteHelper
+ .RunGit(
+ null,
+ null,
+ "clone",
+ "--branch",
+ versionOptions.VersionTag,
+ GithubUrl,
+ $"\"{installLocation}\""
+ )
+ .ConfigureAwait(false);
+ }
+ else if (!string.IsNullOrWhiteSpace(versionOptions.BranchName))
{
- Directory.CreateDirectory(DownloadLocation.Replace($"{Name}.zip", ""));
+ await PrerequisiteHelper
+ .RunGit(
+ null,
+ null,
+ "clone",
+ "--branch",
+ versionOptions.BranchName,
+ GithubUrl,
+ $"\"{installLocation}\""
+ )
+ .ConfigureAwait(false);
}
- await DownloadService
- .DownloadToFileAsync(downloadUrl, DownloadLocation, progress: progress)
- .ConfigureAwait(false);
+ if (!versionOptions.IsLatest && !string.IsNullOrWhiteSpace(versionOptions.CommitHash))
+ {
+ await PrerequisiteHelper
+ .RunGit(installLocation, null, "checkout", versionOptions.CommitHash)
+ .ConfigureAwait(false);
+ }
progress?.Report(new ProgressReport(100, message: "Download Complete"));
}
- public override async Task InstallPackage(
- string installLocation,
- TorchVersion torchVersion,
- IProgress? progress = null,
- Action? onConsoleOutput = null
- )
- {
- await UnzipPackage(installLocation, progress).ConfigureAwait(false);
- File.Delete(DownloadLocation);
- }
-
protected Task UnzipPackage(string installLocation, IProgress? progress = null)
{
using var zip = ZipFile.OpenRead(DownloadLocation);
@@ -279,6 +293,7 @@ public abstract class BaseGitPackage : BasePackage
public override async Task Update(
InstalledPackage installedPackage,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
bool includePrerelease = false,
Action? onConsoleOutput = null
@@ -287,47 +302,146 @@ public abstract class BaseGitPackage : BasePackage
if (installedPackage.Version == null)
throw new NullReferenceException("Version is null");
- if (installedPackage.Version.IsReleaseMode)
+ if (!Directory.Exists(Path.Combine(installedPackage.FullPath!, ".git")))
{
- var releases = await GetAllReleases().ConfigureAwait(false);
- var latestRelease = releases.First(x => includePrerelease || !x.Prerelease);
+ Logger.Info("not a git repo, initializing...");
+ progress?.Report(
+ new ProgressReport(-1f, "Initializing git repo", isIndeterminate: true)
+ );
+ await PrerequisiteHelper
+ .RunGit(installedPackage.FullPath!, onConsoleOutput, "init")
+ .ConfigureAwait(false);
+ await PrerequisiteHelper
+ .RunGit(
+ installedPackage.FullPath!,
+ onConsoleOutput,
+ "remote",
+ "add",
+ "origin",
+ GithubUrl
+ )
+ .ConfigureAwait(false);
+ }
- await DownloadPackage(
- installedPackage.FullPath,
- new DownloadPackageVersionOptions { VersionTag = latestRelease.TagName },
- progress
+ if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag))
+ {
+ progress?.Report(new ProgressReport(-1f, "Fetching tags...", isIndeterminate: true));
+ await PrerequisiteHelper
+ .RunGit(installedPackage.FullPath!, onConsoleOutput, "fetch", "--tags")
+ .ConfigureAwait(false);
+
+ progress?.Report(
+ new ProgressReport(
+ -1f,
+ $"Checking out {versionOptions.VersionTag}",
+ isIndeterminate: true
+ )
+ );
+ await PrerequisiteHelper
+ .RunGit(
+ installedPackage.FullPath!,
+ onConsoleOutput,
+ "checkout",
+ versionOptions.VersionTag,
+ "--force"
)
.ConfigureAwait(false);
- await InstallPackage(installedPackage.FullPath, torchVersion, progress, onConsoleOutput)
+ await InstallPackage(
+ installedPackage.FullPath!,
+ torchVersion,
+ new DownloadPackageVersionOptions
+ {
+ VersionTag = versionOptions.VersionTag,
+ IsLatest = versionOptions.IsLatest
+ },
+ progress,
+ onConsoleOutput
+ )
.ConfigureAwait(false);
- return new InstalledPackageVersion { InstalledReleaseVersion = latestRelease.TagName };
+ return new InstalledPackageVersion
+ {
+ InstalledReleaseVersion = versionOptions.VersionTag
+ };
}
- // Commit mode
- var allCommits = await GetAllCommits(installedPackage.Version.InstalledBranch)
+ // fetch
+ progress?.Report(new ProgressReport(-1f, "Fetching data...", isIndeterminate: true));
+ await PrerequisiteHelper
+ .RunGit(installedPackage.FullPath!, onConsoleOutput, "fetch")
.ConfigureAwait(false);
- var latestCommit = allCommits?.First();
- if (latestCommit is null || string.IsNullOrEmpty(latestCommit.Sha))
+ if (versionOptions.IsLatest)
{
- throw new Exception("No commits found for branch");
+ // checkout
+ progress?.Report(
+ new ProgressReport(
+ -1f,
+ $"Checking out {installedPackage.Version.InstalledBranch}...",
+ isIndeterminate: true
+ )
+ );
+ await PrerequisiteHelper
+ .RunGit(
+ installedPackage.FullPath!,
+ onConsoleOutput,
+ "checkout",
+ versionOptions.BranchName,
+ "--force"
+ )
+ .ConfigureAwait(false);
+
+ // pull
+ progress?.Report(new ProgressReport(-1f, "Pulling changes...", isIndeterminate: true));
+ await PrerequisiteHelper
+ .RunGit(
+ installedPackage.FullPath!,
+ onConsoleOutput,
+ "pull",
+ "origin",
+ installedPackage.Version.InstalledBranch
+ )
+ .ConfigureAwait(false);
+ }
+ else
+ {
+ // checkout
+ progress?.Report(
+ new ProgressReport(
+ -1f,
+ $"Checking out {installedPackage.Version.InstalledBranch}...",
+ isIndeterminate: true
+ )
+ );
+ await PrerequisiteHelper
+ .RunGit(
+ installedPackage.FullPath!,
+ onConsoleOutput,
+ "checkout",
+ versionOptions.CommitHash,
+ "--force"
+ )
+ .ConfigureAwait(false);
}
- await DownloadPackage(
+ await InstallPackage(
installedPackage.FullPath,
- new DownloadPackageVersionOptions { CommitHash = latestCommit.Sha },
- progress
+ torchVersion,
+ new DownloadPackageVersionOptions
+ {
+ CommitHash = versionOptions.CommitHash,
+ IsLatest = versionOptions.IsLatest
+ },
+ progress,
+ onConsoleOutput
)
.ConfigureAwait(false);
- await InstallPackage(installedPackage.FullPath, torchVersion, progress, onConsoleOutput)
- .ConfigureAwait(false);
return new InstalledPackageVersion
{
- InstalledBranch = installedPackage.Version.InstalledBranch,
- InstalledCommitSha = latestCommit.Sha
+ InstalledBranch = versionOptions.BranchName,
+ InstalledCommitSha = versionOptions.CommitHash
};
}
@@ -372,7 +486,37 @@ public abstract class BaseGitPackage : BasePackage
{
if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink)
{
- StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(this, installDirectory);
+ StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(
+ SharedFolders,
+ installDirectory
+ );
+ }
+ return Task.CompletedTask;
+ }
+
+ public override Task SetupOutputFolderLinks(DirectoryPath installDirectory)
+ {
+ if (SharedOutputFolders is { } sharedOutputFolders)
+ {
+ return StabilityMatrix.Core.Helper.SharedFolders.UpdateLinksForPackage(
+ sharedOutputFolders,
+ SettingsManager.ImagesDirectory,
+ installDirectory,
+ recursiveDelete: true
+ );
+ }
+
+ return Task.CompletedTask;
+ }
+
+ public override Task RemoveOutputFolderLinks(DirectoryPath installDirectory)
+ {
+ if (SharedOutputFolders is { } sharedOutputFolders)
+ {
+ StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(
+ sharedOutputFolders,
+ installDirectory
+ );
}
return Task.CompletedTask;
}
diff --git a/StabilityMatrix.Core/Models/Packages/BasePackage.cs b/StabilityMatrix.Core/Models/Packages/BasePackage.cs
index d9fd2289..6165c697 100644
--- a/StabilityMatrix.Core/Models/Packages/BasePackage.cs
+++ b/StabilityMatrix.Core/Models/Packages/BasePackage.cs
@@ -38,6 +38,14 @@ public abstract class BasePackage
public virtual bool IsInferenceCompatible => false;
+ public abstract string OutputFolderName { get; }
+
+ public abstract IEnumerable AvailableTorchVersions { get; }
+
+ public virtual bool IsCompatible => GetRecommendedTorchVersion() != TorchVersion.Cpu;
+
+ public abstract PackageDifficulty InstallerSortOrder { get; }
+
public abstract Task DownloadPackage(
string installLocation,
DownloadPackageVersionOptions versionOptions,
@@ -47,6 +55,7 @@ public abstract class BasePackage
public abstract Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
);
@@ -63,6 +72,7 @@ public abstract class BasePackage
public abstract Task Update(
InstalledPackage installedPackage,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
bool includePrerelease = false,
Action? onConsoleOutput = null
@@ -93,7 +103,8 @@ public abstract class BasePackage
SharedFolderMethod sharedFolderMethod
);
- public abstract IEnumerable AvailableTorchVersions { get; }
+ public abstract Task SetupOutputFolderLinks(DirectoryPath installDirectory);
+ public abstract Task RemoveOutputFolderLinks(DirectoryPath installDirectory);
public virtual TorchVersion GetRecommendedTorchVersion()
{
@@ -142,7 +153,11 @@ public abstract class BasePackage
/// The shared folders that this package supports.
/// Mapping of to the relative paths from the package root.
///
- public virtual Dictionary>? SharedFolders { get; }
+ public abstract Dictionary>? SharedFolders { get; }
+ public abstract Dictionary<
+ SharedOutputType,
+ IReadOnlyList
+ >? SharedOutputFolders { get; }
public abstract Task GetLatestVersion();
public abstract Task GetAllVersionOptions();
@@ -176,9 +191,15 @@ public abstract class BasePackage
);
await venvRunner
- .PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, onConsoleOutput)
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.0.1")
+ .WithTorchVision("==0.15.2")
+ .WithXFormers("==0.0.20")
+ .WithTorchExtraIndex("cu118"),
+ onConsoleOutput
+ )
.ConfigureAwait(false);
- await venvRunner.PipInstall("xformers==0.0.20", onConsoleOutput).ConfigureAwait(false);
}
protected Task InstallDirectMlTorch(
@@ -191,7 +212,7 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true)
);
- return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput);
+ return venvRunner.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput);
}
protected Task InstallCpuTorch(
@@ -204,6 +225,9 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true)
);
- return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, onConsoleOutput);
+ return venvRunner.PipInstall(
+ new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(),
+ onConsoleOutput
+ );
}
}
diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs
index af3881fb..c81055a3 100644
--- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs
+++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs
@@ -1,6 +1,7 @@
using System.Diagnostics;
using System.Text.RegularExpressions;
using NLog;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.FileInterfaces;
@@ -14,6 +15,7 @@ using YamlDotNet.Serialization.NamingConventions;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class ComfyUI : BaseGitPackage
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -30,6 +32,9 @@ public class ComfyUI : BaseGitPackage
new("https://github.com/comfyanonymous/ComfyUI/raw/master/comfyui_screenshot.png");
public override bool ShouldIgnoreReleases => true;
public override bool IsInferenceCompatible => true;
+ public override string OutputFolderName => "output";
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced;
+
public override SharedFolderMethod RecommendedSharedFolderMethod =>
SharedFolderMethod.Configuration;
@@ -58,6 +63,9 @@ public class ComfyUI : BaseGitPackage
[SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
};
+ public override Dictionary>? SharedOutputFolders =>
+ new() { [SharedOutputType.Text2Img] = new[] { "output" } };
+
public override List LaunchOptions =>
new List
{
@@ -141,17 +149,23 @@ public class ComfyUI : BaseGitPackage
public override Task GetLatestVersion() => Task.FromResult("master");
public override IEnumerable AvailableTorchVersions =>
- new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm };
+ new[]
+ {
+ TorchVersion.Cpu,
+ TorchVersion.Cuda,
+ TorchVersion.DirectMl,
+ TorchVersion.Rocm,
+ TorchVersion.Mps
+ };
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
{
- await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
-
progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true));
// Setup venv
await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
@@ -165,13 +179,36 @@ public class ComfyUI : BaseGitPackage
await InstallCpuTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
break;
case TorchVersion.Cuda:
- await InstallCudaTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
+ await venvRunner
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("~=2.1.0")
+ .WithTorchVision()
+ .WithXFormers("==0.0.22.post4")
+ .AddArg("--upgrade")
+ .WithTorchExtraIndex("cu121"),
+ onConsoleOutput
+ )
+ .ConfigureAwait(false);
+ break;
+ case TorchVersion.DirectMl:
+ await venvRunner
+ .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
+ .ConfigureAwait(false);
break;
case TorchVersion.Rocm:
await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
break;
- case TorchVersion.DirectMl:
- await InstallDirectMlTorch(venvRunner, progress, onConsoleOutput)
+ case TorchVersion.Mps:
+ await venvRunner
+ .PipInstall(
+ new PipInstallArgs()
+ .AddArg("--pre")
+ .WithTorch()
+ .WithTorchVision()
+ .WithTorchExtraIndex("nightly/cpu"),
+ onConsoleOutput
+ )
.ConfigureAwait(false);
break;
default:
@@ -441,7 +478,13 @@ public class ComfyUI : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
- .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, onConsoleOutput)
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.0.1")
+ .WithTorchVision()
+ .WithTorchExtraIndex("rocm5.6"),
+ onConsoleOutput
+ )
.ConfigureAwait(false);
}
diff --git a/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs b/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs
index ea1e34b5..54eba855 100644
--- a/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs
+++ b/StabilityMatrix.Core/Models/Packages/DankDiffusion.cs
@@ -1,6 +1,7 @@
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;
@@ -30,6 +31,20 @@ public class DankDiffusion : BaseGitPackage
public override Uri PreviewImageUri { get; }
+ public override string OutputFolderName { get; }
+ public override PackageDifficulty InstallerSortOrder { get; }
+
+ public override Task InstallPackage(
+ string installLocation,
+ TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null
+ )
+ {
+ throw new NotImplementedException();
+ }
+
public override Task RunPackage(
string installedPackagePath,
string command,
@@ -68,6 +83,12 @@ public class DankDiffusion : BaseGitPackage
public override List LaunchOptions { get; }
+ public override Dictionary>? SharedFolders { get; }
+ public override Dictionary<
+ SharedOutputType,
+ IReadOnlyList
+ >? SharedOutputFolders { get; }
+
public override Task GetLatestVersion()
{
throw new NotImplementedException();
diff --git a/StabilityMatrix.Core/Models/Packages/FocusControlNet.cs b/StabilityMatrix.Core/Models/Packages/FocusControlNet.cs
new file mode 100644
index 00000000..b86d0a44
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Packages/FocusControlNet.cs
@@ -0,0 +1,37 @@
+using System.Diagnostics;
+using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.Cache;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Models.Packages;
+
+[Singleton(typeof(BasePackage))]
+public class FocusControlNet : Fooocus
+{
+ public FocusControlNet(
+ IGithubApiCache githubApi,
+ ISettingsManager settingsManager,
+ IDownloadService downloadService,
+ IPrerequisiteHelper prerequisiteHelper
+ )
+ : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
+
+ public override string Name => "Fooocus-ControlNet-SDXL";
+ public override string DisplayName { get; set; } = "Fooocus-ControlNet";
+ public override string Author => "fenneishi";
+ public override string Blurb =>
+ "Fooocus-ControlNet adds more control to the original Fooocus software.";
+ public override string LicenseType => "GPL-3.0";
+ public override string LicenseUrl =>
+ "https://github.com/fenneishi/Fooocus-ControlNet-SDXL/blob/main/LICENSE";
+ public override string LaunchCommand => "launch.py";
+ public override Uri PreviewImageUri =>
+ new("https://github.com/fenneishi/Fooocus-ControlNet-SDXL/raw/main/asset/canny/snip.png");
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced;
+}
diff --git a/StabilityMatrix.Core/Models/Packages/Fooocus.cs b/StabilityMatrix.Core/Models/Packages/Fooocus.cs
index 424f3553..b4ad8830 100644
--- a/StabilityMatrix.Core/Models/Packages/Fooocus.cs
+++ b/StabilityMatrix.Core/Models/Packages/Fooocus.cs
@@ -1,14 +1,17 @@
using System.Diagnostics;
using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class Fooocus : BaseGitPackage
{
public Fooocus(
@@ -38,6 +41,12 @@ public class Fooocus : BaseGitPackage
public override List LaunchOptions =>
new()
{
+ new LaunchOptionDefinition
+ {
+ Name = "Preset",
+ Type = LaunchOptionType.Bool,
+ Options = { "--preset anime", "--preset realistic" }
+ },
new LaunchOptionDefinition
{
Name = "Port",
@@ -59,6 +68,49 @@ public class Fooocus : BaseGitPackage
Description = "Set the listen interface",
Options = { "--listen" }
},
+ new LaunchOptionDefinition
+ {
+ Name = "Output Directory",
+ Type = LaunchOptionType.String,
+ Description = "Override the output directory",
+ Options = { "--output-directory" }
+ },
+ new()
+ {
+ Name = "VRAM",
+ Type = LaunchOptionType.Bool,
+ InitialValue = HardwareHelper
+ .IterGpuInfo()
+ .Select(gpu => gpu.MemoryLevel)
+ .Max() switch
+ {
+ Level.Low => "--lowvram",
+ Level.Medium => "--normalvram",
+ _ => null
+ },
+ Options = { "--highvram", "--normalvram", "--lowvram", "--novram" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Use DirectML",
+ Type = LaunchOptionType.Bool,
+ Description = "Use pytorch with DirectML support",
+ InitialValue = HardwareHelper.PreferDirectML(),
+ Options = { "--directml" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Disable Xformers",
+ Type = LaunchOptionType.Bool,
+ InitialValue = !HardwareHelper.HasNvidiaGpu(),
+ Options = { "--disable-xformers" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Auto-Launch",
+ Type = LaunchOptionType.Bool,
+ Options = { "--auto-launch" }
+ },
LaunchOptionDefinition.Extras
};
@@ -83,48 +135,59 @@ public class Fooocus : BaseGitPackage
[SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" }
};
+ public override Dictionary>? SharedOutputFolders =>
+ new() { [SharedOutputType.Text2Img] = new[] { "outputs" } };
+
public override IEnumerable AvailableTorchVersions =>
- new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm };
+ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm };
public override Task GetLatestVersion() => Task.FromResult("main");
public override bool ShouldIgnoreReleases => true;
+ public override string OutputFolderName => "outputs";
+
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Simple;
+
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
{
- await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
var venvRunner = await SetupVenv(installLocation, forceRecreate: true)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
- var torchVersionStr = "cpu";
-
- switch (torchVersion)
+ if (torchVersion == TorchVersion.DirectMl)
{
- case TorchVersion.Cuda:
- torchVersionStr = "cu118";
- break;
- case TorchVersion.Rocm:
- torchVersionStr = "rocm5.4.2";
- break;
- case TorchVersion.Cpu:
- break;
- default:
- throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null);
+ await venvRunner
+ .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
+ .ConfigureAwait(false);
+ }
+ else
+ {
+ var extraIndex = torchVersion switch
+ {
+ TorchVersion.Cpu => "cpu",
+ TorchVersion.Cuda => "cu121",
+ TorchVersion.Rocm => "rocm5.6",
+ _ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
+ };
+
+ await venvRunner
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.1.0")
+ .WithTorchVision("==0.16.0")
+ .WithTorchExtraIndex(extraIndex),
+ onConsoleOutput
+ )
+ .ConfigureAwait(false);
}
-
- await venvRunner
- .PipInstall(
- $"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersionStr}",
- onConsoleOutput
- )
- .ConfigureAwait(false);
var requirements = new FilePath(installLocation, "requirements_versions.txt");
await venvRunner
diff --git a/StabilityMatrix.Core/Models/Packages/FooocusMre.cs b/StabilityMatrix.Core/Models/Packages/FooocusMre.cs
index fb65ce28..2313d444 100644
--- a/StabilityMatrix.Core/Models/Packages/FooocusMre.cs
+++ b/StabilityMatrix.Core/Models/Packages/FooocusMre.cs
@@ -1,14 +1,17 @@
using System.Diagnostics;
using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class FooocusMre : BaseGitPackage
{
public FooocusMre(
@@ -37,6 +40,11 @@ public class FooocusMre : BaseGitPackage
"https://user-images.githubusercontent.com/130458190/265366059-ce430ea0-0995-4067-98dd-cef1d7dc1ab6.png"
);
+ public override string Disclaimer =>
+ "This package may no longer receive updates from its author. It may be removed from Stability Matrix in the future.";
+
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Impossible;
+
public override List LaunchOptions =>
new()
{
@@ -85,6 +93,9 @@ public class FooocusMre : BaseGitPackage
[SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" }
};
+ public override Dictionary>? SharedOutputFolders =>
+ new() { [SharedOutputType.Text2Img] = new[] { "outputs" } };
+
public override IEnumerable AvailableTorchVersions =>
new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm };
@@ -94,41 +105,47 @@ public class FooocusMre : BaseGitPackage
return release.TagName!;
}
+ public override string OutputFolderName => "outputs";
+
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
{
- await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
var venvRunner = await SetupVenv(installLocation, forceRecreate: true)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
- var torchVersionStr = "cpu";
-
- switch (torchVersion)
+ if (torchVersion == TorchVersion.DirectMl)
{
- case TorchVersion.Cuda:
- torchVersionStr = "cu118";
- break;
- case TorchVersion.Rocm:
- torchVersionStr = "rocm5.4.2";
- break;
- case TorchVersion.Cpu:
- break;
- default:
- throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null);
+ await venvRunner
+ .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
+ .ConfigureAwait(false);
+ }
+ else
+ {
+ var extraIndex = torchVersion switch
+ {
+ TorchVersion.Cpu => "cpu",
+ TorchVersion.Cuda => "cu118",
+ TorchVersion.Rocm => "rocm5.4.2",
+ _ => throw new ArgumentOutOfRangeException(nameof(torchVersion), torchVersion, null)
+ };
+
+ await venvRunner
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.0.1")
+ .WithTorchVision("==0.15.2")
+ .WithTorchExtraIndex(extraIndex),
+ onConsoleOutput
+ )
+ .ConfigureAwait(false);
}
-
- await venvRunner
- .PipInstall(
- $"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersionStr}",
- onConsoleOutput
- )
- .ConfigureAwait(false);
var requirements = new FilePath(installLocation, "requirements_versions.txt");
await venvRunner
diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs
index 0284143c..2d86d438 100644
--- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs
+++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs
@@ -1,5 +1,7 @@
-using System.Text.RegularExpressions;
+using System.Globalization;
+using System.Text.RegularExpressions;
using NLog;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
@@ -11,6 +13,7 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class InvokeAI : BaseGitPackage
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -20,11 +23,11 @@ public class InvokeAI : BaseGitPackage
public override string DisplayName { get; set; } = "InvokeAI";
public override string Author => "invoke-ai";
public override string LicenseType => "Apache-2.0";
-
public override string LicenseUrl => "https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE";
public override string Blurb => "Professional Creative Tools for Stable Diffusion";
public override string LaunchCommand => "invokeai-web";
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Nightmare;
public override IReadOnlyList ExtraLaunchCommands =>
new[]
@@ -43,8 +46,6 @@ public class InvokeAI : BaseGitPackage
"https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/docs/assets/canvas_preview.png"
);
- public override bool ShouldIgnoreReleases => true;
-
public override IEnumerable AvailableSharedFolderMethods =>
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
@@ -60,15 +61,35 @@ public class InvokeAI : BaseGitPackage
public override Dictionary> SharedFolders =>
new()
{
- [SharedFolderType.StableDiffusion] = new[] { RelativeRootPath + "/autoimport/main" },
- [SharedFolderType.Lora] = new[] { RelativeRootPath + "/autoimport/lora" },
+ [SharedFolderType.StableDiffusion] = new[]
+ {
+ Path.Combine(RelativeRootPath, "autoimport", "main")
+ },
+ [SharedFolderType.Lora] = new[]
+ {
+ Path.Combine(RelativeRootPath, "autoimport", "lora")
+ },
[SharedFolderType.TextualInversion] = new[]
{
- RelativeRootPath + "/autoimport/embedding"
+ Path.Combine(RelativeRootPath, "autoimport", "embedding")
},
- [SharedFolderType.ControlNet] = new[] { RelativeRootPath + "/autoimport/controlnet" },
+ [SharedFolderType.ControlNet] = new[]
+ {
+ Path.Combine(RelativeRootPath, "autoimport", "controlnet")
+ }
};
+ public override Dictionary>? SharedOutputFolders =>
+ new()
+ {
+ [SharedOutputType.Text2Img] = new[]
+ {
+ Path.Combine("invokeai-root", "outputs", "images")
+ }
+ };
+
+ public override string OutputFolderName => Path.Combine("invokeai-root", "outputs", "images");
+
// https://github.com/invoke-ai/InvokeAI/blob/main/docs/features/CONFIGURATION.md
public override List LaunchOptions =>
new List
@@ -123,7 +144,11 @@ public class InvokeAI : BaseGitPackage
LaunchOptionDefinition.Extras
};
- public override Task GetLatestVersion() => Task.FromResult("main");
+ public override async Task GetLatestVersion()
+ {
+ var release = await GetLatestRelease().ConfigureAwait(false);
+ return release.TagName!;
+ }
public override IEnumerable AvailableTorchVersions =>
new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm, TorchVersion.Mps };
@@ -138,18 +163,10 @@ public class InvokeAI : BaseGitPackage
return base.GetRecommendedTorchVersion();
}
- public override Task DownloadPackage(
- string installLocation,
- DownloadPackageVersionOptions downloadOptions,
- IProgress? progress = null
- )
- {
- return Task.CompletedTask;
- }
-
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
@@ -157,7 +174,10 @@ public class InvokeAI : BaseGitPackage
// Setup venv
progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
- await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
+ var venvPath = Path.Combine(installLocation, "venv");
+ var exists = Directory.Exists(venvPath);
+
+ await using var venvRunner = new PyVenvRunner(venvPath);
venvRunner.WorkingDirectory = installLocation;
await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false);
@@ -165,29 +185,42 @@ public class InvokeAI : BaseGitPackage
progress?.Report(new ProgressReport(-1f, "Installing Package", isIndeterminate: true));
var pipCommandArgs =
- "InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu";
+ "-e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu";
switch (torchVersion)
{
+ // If has Nvidia Gpu, install CUDA version
case TorchVersion.Cuda:
+ await InstallCudaTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
Logger.Info("Starting InvokeAI install (CUDA)...");
pipCommandArgs =
- "InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117";
+ "-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118";
break;
-
+ // For AMD, Install ROCm version
case TorchVersion.Rocm:
+ await venvRunner
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.0.1")
+ .WithTorchVision()
+ .WithExtraIndex("rocm5.4.2"),
+ onConsoleOutput
+ )
+ .ConfigureAwait(false);
Logger.Info("Starting InvokeAI install (ROCm)...");
pipCommandArgs =
- "InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2";
+ "-e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2";
break;
-
case TorchVersion.Mps:
+ // For Apple silicon, use MPS
Logger.Info("Starting InvokeAI install (MPS)...");
- pipCommandArgs = "InvokeAI --use-pep517";
+ pipCommandArgs = "-e . --use-pep517";
break;
}
- await venvRunner.PipInstall(pipCommandArgs, onConsoleOutput).ConfigureAwait(false);
+ await venvRunner
+ .PipInstall($"{pipCommandArgs}{(exists ? " --upgrade" : "")}", onConsoleOutput)
+ .ConfigureAwait(false);
await venvRunner
.PipInstall("rich packaging python-dotenv", onConsoleOutput)
@@ -207,75 +240,6 @@ public class InvokeAI : BaseGitPackage
progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false));
}
- public override async Task Update(
- InstalledPackage installedPackage,
- TorchVersion torchVersion,
- IProgress? progress = null,
- bool includePrerelease = false,
- Action? onConsoleOutput = null
- )
- {
- progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
-
- if (installedPackage.FullPath is null || installedPackage.Version is null)
- {
- throw new NullReferenceException("Installed package is missing Path and/or Version");
- }
-
- await using var venvRunner = new PyVenvRunner(
- Path.Combine(installedPackage.FullPath, "venv")
- );
- venvRunner.WorkingDirectory = installedPackage.FullPath;
- venvRunner.EnvironmentVariables = GetEnvVars(installedPackage.FullPath);
-
- var latestVersion = await GetUpdateVersion(installedPackage).ConfigureAwait(false);
- var isReleaseMode = installedPackage.Version.IsReleaseMode;
-
- var downloadUrl = isReleaseMode
- ? $"https://github.com/invoke-ai/InvokeAI/archive/{latestVersion}.zip"
- : $"https://github.com/invoke-ai/InvokeAI/archive/refs/heads/{installedPackage.Version.InstalledBranch}.zip";
-
- var gpus = HardwareHelper.IterGpuInfo().ToList();
-
- progress?.Report(new ProgressReport(-1f, "Installing Package", isIndeterminate: true));
-
- var pipCommandArgs =
- $"\"InvokeAI @ {downloadUrl}\" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu --upgrade";
-
- switch (torchVersion)
- {
- // If has Nvidia Gpu, install CUDA version
- case TorchVersion.Cuda:
- Logger.Info("Starting InvokeAI install (CUDA)...");
- pipCommandArgs =
- $"\"InvokeAI[xformers] @ {downloadUrl}\" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117 --upgrade";
- break;
- // For AMD, Install ROCm version
- case TorchVersion.Rocm:
- Logger.Info("Starting InvokeAI install (ROCm)...");
- pipCommandArgs =
- $"\"InvokeAI @ {downloadUrl}\" --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2 --upgrade";
- break;
- case TorchVersion.Mps:
- // For Apple silicon, use MPS
- Logger.Info("Starting InvokeAI install (MPS)...");
- pipCommandArgs = $"\"InvokeAI @ {downloadUrl}\" --use-pep517 --upgrade";
- break;
- }
-
- await venvRunner.PipInstall(pipCommandArgs, onConsoleOutput).ConfigureAwait(false);
-
- progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false));
-
- return isReleaseMode
- ? new InstalledPackageVersion { InstalledReleaseVersion = latestVersion }
- : new InstalledPackageVersion
- {
- InstalledBranch = installedPackage.Version.InstalledBranch,
- InstalledCommitSha = latestVersion
- };
- }
-
public override Task RunPackage(
string installedPackagePath,
string command,
@@ -283,27 +247,6 @@ public class InvokeAI : BaseGitPackage
Action? onConsoleOutput
) => RunInvokeCommand(installedPackagePath, command, arguments, true, onConsoleOutput);
- private async Task GetUpdateVersion(
- InstalledPackage installedPackage,
- bool includePrerelease = false
- )
- {
- if (installedPackage.Version == null)
- throw new NullReferenceException("Installed package version is null");
-
- if (installedPackage.Version.IsReleaseMode)
- {
- var releases = await GetAllReleases().ConfigureAwait(false);
- var latestRelease = releases.First(x => includePrerelease || !x.Prerelease);
- return latestRelease.TagName;
- }
-
- var allCommits = await GetAllCommits(installedPackage.Version.InstalledBranch)
- .ConfigureAwait(false);
- var latestCommit = allCommits.First();
- return latestCommit.Sha;
- }
-
private async Task RunInvokeCommand(
string installedPackagePath,
string command,
@@ -317,7 +260,6 @@ public class InvokeAI : BaseGitPackage
arguments = command switch
{
"invokeai-configure" => "--yes --skip-sd-weights",
- "invokeai-model-install" => "--yes",
_ => arguments
};
@@ -340,6 +282,21 @@ public class InvokeAI : BaseGitPackage
// above the minimum in invokeai.frontend.install.widgets
var code = $"""
+ try:
+ import os
+ import shutil
+ from invokeai.frontend.install import widgets
+
+ _min_cols = widgets.MIN_COLS
+ _min_lines = widgets.MIN_LINES
+
+ static_size_fn = lambda: os.terminal_size((_min_cols, _min_lines))
+ shutil.get_terminal_size = static_size_fn
+ widgets.get_terminal_size = static_size_fn
+ except Exception as e:
+ import warnings
+ warnings.warn('Could not patch terminal size for InvokeAI' + str(e))
+
import sys
from {split[0]} import {split[1]}
sys.exit({split[1]}())
diff --git a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs
new file mode 100644
index 00000000..ed9d92fc
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs
@@ -0,0 +1,251 @@
+using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.Cache;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Models.Packages;
+
+[Singleton(typeof(BasePackage))]
+public class KohyaSs : BaseGitPackage
+{
+ public KohyaSs(
+ IGithubApiCache githubApi,
+ ISettingsManager settingsManager,
+ IDownloadService downloadService,
+ IPrerequisiteHelper prerequisiteHelper
+ )
+ : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
+
+ public override string Name => "kohya_ss";
+ public override string DisplayName { get; set; } = "kohya_ss";
+ public override string Author => "bmaltais";
+ public override string Blurb =>
+ "A Windows-focused Gradio GUI for Kohya's Stable Diffusion trainers";
+ public override string LicenseType => "Apache-2.0";
+ public override string LicenseUrl =>
+ "https://github.com/bmaltais/kohya_ss/blob/master/LICENSE.md";
+ public override string LaunchCommand => "kohya_gui.py";
+
+ public override Uri PreviewImageUri =>
+ new(
+ "https://camo.githubusercontent.com/2170d2204816f428eec57ff87218f06344e0b4d91966343a6c5f0a76df91ec75/68747470733a2f2f696d672e796f75747562652e636f6d2f76692f6b35696d713031757655592f302e6a7067"
+ );
+ public override string OutputFolderName => string.Empty;
+
+ public override bool IsCompatible => HardwareHelper.HasNvidiaGpu();
+
+ public override TorchVersion GetRecommendedTorchVersion() => TorchVersion.Cuda;
+
+ public override string Disclaimer =>
+ "Nvidia GPU with at least 8GB VRAM is recommended. May be unstable on Linux.";
+
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.UltraNightmare;
+
+ public override bool OfferInOneClickInstaller => false;
+ public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.None;
+ public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.Cuda };
+ public override IEnumerable AvailableSharedFolderMethods =>
+ new[] { SharedFolderMethod.None };
+
+ public override List LaunchOptions =>
+ new()
+ {
+ new LaunchOptionDefinition
+ {
+ Name = "Listen Address",
+ Type = LaunchOptionType.String,
+ DefaultValue = "127.0.0.1",
+ Options = new List { "--listen" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Port",
+ Type = LaunchOptionType.String,
+ Options = new List { "--port" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Username",
+ Type = LaunchOptionType.String,
+ Options = new List { "--username" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Password",
+ Type = LaunchOptionType.String,
+ Options = new List { "--password" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Auto-Launch Browser",
+ Type = LaunchOptionType.Bool,
+ Options = new List { "--inbrowser" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Share",
+ Type = LaunchOptionType.Bool,
+ Options = new List { "--share" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Headless",
+ Type = LaunchOptionType.Bool,
+ Options = new List { "--headless" }
+ },
+ new LaunchOptionDefinition
+ {
+ Name = "Language",
+ Type = LaunchOptionType.String,
+ Options = new List { "--language" }
+ },
+ LaunchOptionDefinition.Extras
+ };
+
+ public override async Task InstallPackage(
+ string installLocation,
+ TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null
+ )
+ {
+ if (Compat.IsWindows)
+ {
+ progress?.Report(
+ new ProgressReport(-1f, "Installing prerequisites...", isIndeterminate: true)
+ );
+ await PrerequisiteHelper.InstallTkinterIfNecessary(progress).ConfigureAwait(false);
+ }
+
+ progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
+ // Setup venv
+ await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
+ venvRunner.WorkingDirectory = installLocation;
+ await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false);
+
+ if (Compat.IsWindows)
+ {
+ var setupSmPath = Path.Combine(installLocation, "setup", "setup_sm.py");
+ var setupText = """
+ import setup_windows
+ import setup_common
+
+ setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=False)
+ setup_windows.sync_bits_and_bytes_files()
+ setup_common.configure_accelerate(run_accelerate=False)
+ """;
+ await File.WriteAllTextAsync(setupSmPath, setupText).ConfigureAwait(false);
+
+ // Install
+ venvRunner.RunDetached("setup/setup_sm.py", onConsoleOutput);
+ await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false);
+ }
+ else if (Compat.IsLinux)
+ {
+ venvRunner.RunDetached(
+ "setup/setup_linux.py --platform-requirements-file=requirements_linux.txt --no_run_accelerate",
+ onConsoleOutput
+ );
+ await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false);
+ }
+ }
+
+ public override async Task RunPackage(
+ string installedPackagePath,
+ string command,
+ string arguments,
+ Action? onConsoleOutput
+ )
+ {
+ await SetupVenv(installedPackagePath).ConfigureAwait(false);
+
+ // update gui files to point to venv accelerate
+ var filesToUpdate = new[]
+ {
+ "lora_gui.py",
+ "dreambooth_gui.py",
+ "textual_inversion_gui.py",
+ Path.Combine("library", "wd14_caption_gui.py"),
+ "finetune_gui.py"
+ };
+
+ foreach (var file in filesToUpdate)
+ {
+ var path = Path.Combine(installedPackagePath, file);
+ var text = await File.ReadAllTextAsync(path).ConfigureAwait(false);
+ var replacementAcceleratePath = Compat.IsWindows
+ ? @".\\venv\\scripts\\accelerate"
+ : "./venv/bin/accelerate";
+ text = text.Replace(
+ "run_cmd = f'accelerate launch",
+ $"run_cmd = f'{replacementAcceleratePath} launch"
+ );
+ await File.WriteAllTextAsync(path, text).ConfigureAwait(false);
+ }
+
+ void HandleConsoleOutput(ProcessOutput s)
+ {
+ onConsoleOutput?.Invoke(s);
+
+ if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase))
+ return;
+
+ var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
+ var match = regex.Match(s.Text);
+ if (!match.Success)
+ return;
+
+ WebUrl = match.Value;
+ OnStartupComplete(WebUrl);
+ }
+
+ var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
+
+ VenvRunner.EnvironmentVariables = GetEnvVars(installedPackagePath);
+ VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
+ }
+
+ public override Dictionary>? SharedFolders { get; }
+ public override Dictionary<
+ SharedOutputType,
+ IReadOnlyList
+ >? SharedOutputFolders { get; }
+
+ public override async Task GetLatestVersion()
+ {
+ var release = await GetLatestRelease().ConfigureAwait(false);
+ return release.TagName!;
+ }
+
+ private Dictionary GetEnvVars(string installDirectory)
+ {
+ // Set additional required environment variables
+ var env = new Dictionary();
+ if (SettingsManager.Settings.EnvironmentVariables is not null)
+ {
+ env.Update(SettingsManager.Settings.EnvironmentVariables);
+ }
+
+ if (!Compat.IsWindows)
+ return env;
+
+ var tkPath = Path.Combine(
+ SettingsManager.LibraryDir,
+ "Assets",
+ "Python310",
+ "tcl",
+ "tcl8.6"
+ );
+ env["TCL_LIBRARY"] = tkPath;
+ env["TK_LIBRARY"] = tkPath;
+
+ return env;
+ }
+}
diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionDirectMl.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionDirectMl.cs
new file mode 100644
index 00000000..3d5bf19b
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionDirectMl.cs
@@ -0,0 +1,251 @@
+using System.Diagnostics.CodeAnalysis;
+using System.Text.Json.Nodes;
+using System.Text.RegularExpressions;
+using NLog;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.Cache;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Models.Packages;
+
+[Singleton(typeof(BasePackage))]
+public class StableDiffusionDirectMl : BaseGitPackage
+{
+ private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
+
+ public override string Name => "stable-diffusion-webui-directml";
+ public override string DisplayName { get; set; } = "Stable Diffusion Web UI";
+ public override string Author => "lshqqytiger";
+ public override string LicenseType => "AGPL-3.0";
+ public override string LicenseUrl =>
+ "https://github.com/lshqqytiger/stable-diffusion-webui-directml/blob/master/LICENSE.txt";
+ public override string Blurb =>
+ "A fork of Automatic1111's Stable Diffusion WebUI with DirectML support";
+ public override string LaunchCommand => "launch.py";
+ public override Uri PreviewImageUri =>
+ new(
+ "https://github.com/lshqqytiger/stable-diffusion-webui-directml/raw/master/screenshot.png"
+ );
+
+ public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
+
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Recommended;
+
+ public StableDiffusionDirectMl(
+ IGithubApiCache githubApi,
+ ISettingsManager settingsManager,
+ IDownloadService downloadService,
+ IPrerequisiteHelper prerequisiteHelper
+ )
+ : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
+
+ public override Dictionary> SharedFolders =>
+ new()
+ {
+ [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" },
+ [SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" },
+ [SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" },
+ [SharedFolderType.SwinIR] = new[] { "models/SwinIR" },
+ [SharedFolderType.Lora] = new[] { "models/Lora" },
+ [SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" },
+ [SharedFolderType.ApproxVAE] = new[] { "models/VAE-approx" },
+ [SharedFolderType.VAE] = new[] { "models/VAE" },
+ [SharedFolderType.DeepDanbooru] = new[] { "models/deepbooru" },
+ [SharedFolderType.Karlo] = new[] { "models/karlo" },
+ [SharedFolderType.TextualInversion] = new[] { "embeddings" },
+ [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
+ [SharedFolderType.ControlNet] = new[] { "models/ControlNet" },
+ [SharedFolderType.Codeformer] = new[] { "models/Codeformer" },
+ [SharedFolderType.LDSR] = new[] { "models/LDSR" },
+ [SharedFolderType.AfterDetailer] = new[] { "models/adetailer" }
+ };
+
+ public override Dictionary>? SharedOutputFolders =>
+ new()
+ {
+ [SharedOutputType.Extras] = new[] { "outputs/extras-images" },
+ [SharedOutputType.Saved] = new[] { "log/images" },
+ [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" },
+ [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" },
+ [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" },
+ [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" }
+ };
+
+ [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
+ public override List LaunchOptions =>
+ new()
+ {
+ new()
+ {
+ Name = "Host",
+ Type = LaunchOptionType.String,
+ DefaultValue = "localhost",
+ Options = new() { "--server-name" }
+ },
+ new()
+ {
+ Name = "Port",
+ Type = LaunchOptionType.String,
+ DefaultValue = "7860",
+ Options = new() { "--port" }
+ },
+ new()
+ {
+ Name = "VRAM",
+ Type = LaunchOptionType.Bool,
+ InitialValue = HardwareHelper
+ .IterGpuInfo()
+ .Select(gpu => gpu.MemoryLevel)
+ .Max() switch
+ {
+ Level.Low => "--lowvram",
+ Level.Medium => "--medvram",
+ _ => null
+ },
+ Options = new() { "--lowvram", "--medvram", "--medvram-sdxl" }
+ },
+ new()
+ {
+ Name = "Xformers",
+ Type = LaunchOptionType.Bool,
+ InitialValue = HardwareHelper.HasNvidiaGpu(),
+ Options = new() { "--xformers" }
+ },
+ new()
+ {
+ Name = "API",
+ Type = LaunchOptionType.Bool,
+ InitialValue = true,
+ Options = new() { "--api" }
+ },
+ new()
+ {
+ Name = "Auto Launch Web UI",
+ Type = LaunchOptionType.Bool,
+ InitialValue = false,
+ Options = new() { "--autolaunch" }
+ },
+ new()
+ {
+ Name = "Skip Torch CUDA Check",
+ Type = LaunchOptionType.Bool,
+ InitialValue = !HardwareHelper.HasNvidiaGpu(),
+ Options = new() { "--skip-torch-cuda-test" }
+ },
+ new()
+ {
+ Name = "Skip Python Version Check",
+ Type = LaunchOptionType.Bool,
+ InitialValue = true,
+ Options = new() { "--skip-python-version-check" }
+ },
+ new()
+ {
+ Name = "No Half",
+ Type = LaunchOptionType.Bool,
+ Description = "Do not switch the model to 16-bit floats",
+ InitialValue = HardwareHelper.HasAmdGpu(),
+ Options = new() { "--no-half" }
+ },
+ new()
+ {
+ Name = "Skip SD Model Download",
+ Type = LaunchOptionType.Bool,
+ InitialValue = false,
+ Options = new() { "--no-download-sd-model" }
+ },
+ new()
+ {
+ Name = "Skip Install",
+ Type = LaunchOptionType.Bool,
+ Options = new() { "--skip-install" }
+ },
+ LaunchOptionDefinition.Extras
+ };
+
+ public override IEnumerable AvailableSharedFolderMethods =>
+ new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
+
+ public override IEnumerable AvailableTorchVersions =>
+ new[] { TorchVersion.Cpu, TorchVersion.DirectMl };
+
+ public override Task GetLatestVersion() => Task.FromResult("master");
+
+ public override bool ShouldIgnoreReleases => true;
+
+ public override string OutputFolderName => "outputs";
+
+ public override async Task InstallPackage(
+ string installLocation,
+ TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null
+ )
+ {
+ progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
+ // Setup venv
+ await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
+ venvRunner.WorkingDirectory = installLocation;
+ await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false);
+
+ switch (torchVersion)
+ {
+ case TorchVersion.DirectMl:
+ await InstallDirectMlTorch(venvRunner, progress, onConsoleOutput)
+ .ConfigureAwait(false);
+ break;
+ case TorchVersion.Cpu:
+ await InstallCpuTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
+ break;
+ }
+
+ // Install requirements file
+ progress?.Report(
+ new ProgressReport(-1f, "Installing Package Requirements", isIndeterminate: true)
+ );
+ Logger.Info("Installing requirements_versions.txt");
+
+ var requirements = new FilePath(installLocation, "requirements_versions.txt");
+ await venvRunner
+ .PipInstallFromRequirements(requirements, onConsoleOutput, excludes: "torch")
+ .ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false));
+ }
+
+ public override async Task RunPackage(
+ string installedPackagePath,
+ string command,
+ string arguments,
+ Action? onConsoleOutput
+ )
+ {
+ await SetupVenv(installedPackagePath).ConfigureAwait(false);
+
+ void HandleConsoleOutput(ProcessOutput s)
+ {
+ onConsoleOutput?.Invoke(s);
+
+ if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase))
+ return;
+
+ var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
+ var match = regex.Match(s.Text);
+ if (!match.Success)
+ return;
+
+ WebUrl = match.Value;
+ OnStartupComplete(WebUrl);
+ }
+
+ var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
+
+ VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
+ }
+}
diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs
new file mode 100644
index 00000000..6189d40e
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs
@@ -0,0 +1,277 @@
+using System.Diagnostics.CodeAnalysis;
+using System.Text.Json.Nodes;
+using System.Text.RegularExpressions;
+using NLog;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.Cache;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Models.Packages;
+
+[Singleton(typeof(BasePackage))]
+public class StableDiffusionUx : BaseGitPackage
+{
+ private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
+
+ public override string Name => "stable-diffusion-webui-ux";
+ public override string DisplayName { get; set; } = "Stable Diffusion Web UI-UX";
+ public override string Author => "anapnoe";
+ public override string LicenseType => "AGPL-3.0";
+ public override string LicenseUrl =>
+ "https://github.com/anapnoe/stable-diffusion-webui-ux/blob/master/LICENSE.txt";
+ public override string Blurb =>
+ "A pixel perfect design, mobile friendly, customizable interface that adds accessibility, "
+ + "ease of use and extended functionallity to the stable diffusion web ui.";
+ public override string LaunchCommand => "launch.py";
+ public override Uri PreviewImageUri =>
+ new(
+ "https://raw.githubusercontent.com/anapnoe/stable-diffusion-webui-ux/master/screenshot.png"
+ );
+
+ public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
+
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Simple;
+
+ public StableDiffusionUx(
+ IGithubApiCache githubApi,
+ ISettingsManager settingsManager,
+ IDownloadService downloadService,
+ IPrerequisiteHelper prerequisiteHelper
+ )
+ : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
+
+ public override Dictionary> SharedFolders =>
+ new()
+ {
+ [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" },
+ [SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" },
+ [SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" },
+ [SharedFolderType.SwinIR] = new[] { "models/SwinIR" },
+ [SharedFolderType.Lora] = new[] { "models/Lora" },
+ [SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" },
+ [SharedFolderType.ApproxVAE] = new[] { "models/VAE-approx" },
+ [SharedFolderType.VAE] = new[] { "models/VAE" },
+ [SharedFolderType.DeepDanbooru] = new[] { "models/deepbooru" },
+ [SharedFolderType.Karlo] = new[] { "models/karlo" },
+ [SharedFolderType.TextualInversion] = new[] { "embeddings" },
+ [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
+ [SharedFolderType.ControlNet] = new[] { "models/ControlNet" },
+ [SharedFolderType.Codeformer] = new[] { "models/Codeformer" },
+ [SharedFolderType.LDSR] = new[] { "models/LDSR" },
+ [SharedFolderType.AfterDetailer] = new[] { "models/adetailer" }
+ };
+
+ public override Dictionary>? SharedOutputFolders =>
+ new()
+ {
+ [SharedOutputType.Extras] = new[] { "outputs/extras-images" },
+ [SharedOutputType.Saved] = new[] { "log/images" },
+ [SharedOutputType.Img2Img] = new[] { "outputs/img2img-images" },
+ [SharedOutputType.Text2Img] = new[] { "outputs/txt2img-images" },
+ [SharedOutputType.Img2ImgGrids] = new[] { "outputs/img2img-grids" },
+ [SharedOutputType.Text2ImgGrids] = new[] { "outputs/txt2img-grids" }
+ };
+
+ [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
+ public override List LaunchOptions =>
+ new()
+ {
+ new()
+ {
+ Name = "Host",
+ Type = LaunchOptionType.String,
+ DefaultValue = "localhost",
+ Options = new() { "--server-name" }
+ },
+ new()
+ {
+ Name = "Port",
+ Type = LaunchOptionType.String,
+ DefaultValue = "7860",
+ Options = new() { "--port" }
+ },
+ new()
+ {
+ Name = "VRAM",
+ Type = LaunchOptionType.Bool,
+ InitialValue = HardwareHelper
+ .IterGpuInfo()
+ .Select(gpu => gpu.MemoryLevel)
+ .Max() switch
+ {
+ Level.Low => "--lowvram",
+ Level.Medium => "--medvram",
+ _ => null
+ },
+ Options = new() { "--lowvram", "--medvram", "--medvram-sdxl" }
+ },
+ new()
+ {
+ Name = "Xformers",
+ Type = LaunchOptionType.Bool,
+ InitialValue = HardwareHelper.HasNvidiaGpu(),
+ Options = new() { "--xformers" }
+ },
+ new()
+ {
+ Name = "API",
+ Type = LaunchOptionType.Bool,
+ InitialValue = true,
+ Options = new() { "--api" }
+ },
+ new()
+ {
+ Name = "Auto Launch Web UI",
+ Type = LaunchOptionType.Bool,
+ InitialValue = false,
+ Options = new() { "--autolaunch" }
+ },
+ new()
+ {
+ Name = "Skip Torch CUDA Check",
+ Type = LaunchOptionType.Bool,
+ InitialValue = !HardwareHelper.HasNvidiaGpu(),
+ Options = new() { "--skip-torch-cuda-test" }
+ },
+ new()
+ {
+ Name = "Skip Python Version Check",
+ Type = LaunchOptionType.Bool,
+ InitialValue = true,
+ Options = new() { "--skip-python-version-check" }
+ },
+ new()
+ {
+ Name = "No Half",
+ Type = LaunchOptionType.Bool,
+ Description = "Do not switch the model to 16-bit floats",
+ InitialValue = HardwareHelper.HasAmdGpu(),
+ Options = new() { "--no-half" }
+ },
+ new()
+ {
+ Name = "Skip SD Model Download",
+ Type = LaunchOptionType.Bool,
+ InitialValue = false,
+ Options = new() { "--no-download-sd-model" }
+ },
+ new()
+ {
+ Name = "Skip Install",
+ Type = LaunchOptionType.Bool,
+ Options = new() { "--skip-install" }
+ },
+ LaunchOptionDefinition.Extras
+ };
+
+ public override IEnumerable AvailableSharedFolderMethods =>
+ new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
+
+ public override IEnumerable AvailableTorchVersions =>
+ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm };
+
+ public override Task GetLatestVersion() => Task.FromResult("master");
+
+ public override bool ShouldIgnoreReleases => true;
+
+ public override string OutputFolderName => "outputs";
+
+ public override async Task InstallPackage(
+ string installLocation,
+ TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null
+ )
+ {
+ progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
+ // Setup venv
+ await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
+ venvRunner.WorkingDirectory = installLocation;
+ await venvRunner.Setup(true, onConsoleOutput).ConfigureAwait(false);
+
+ switch (torchVersion)
+ {
+ case TorchVersion.Cpu:
+ await InstallCpuTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
+ break;
+ case TorchVersion.Cuda:
+ await InstallCudaTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
+ break;
+ case TorchVersion.Rocm:
+ await InstallRocmTorch(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
+ break;
+ }
+
+ // Install requirements file
+ progress?.Report(
+ new ProgressReport(-1f, "Installing Package Requirements", isIndeterminate: true)
+ );
+ Logger.Info("Installing requirements_versions.txt");
+
+ var requirements = new FilePath(installLocation, "requirements_versions.txt");
+ await venvRunner
+ .PipInstallFromRequirements(requirements, onConsoleOutput, excludes: "torch")
+ .ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false));
+ }
+
+ public override async Task RunPackage(
+ string installedPackagePath,
+ string command,
+ string arguments,
+ Action? onConsoleOutput
+ )
+ {
+ await SetupVenv(installedPackagePath).ConfigureAwait(false);
+
+ void HandleConsoleOutput(ProcessOutput s)
+ {
+ onConsoleOutput?.Invoke(s);
+
+ if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase))
+ return;
+
+ var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
+ var match = regex.Match(s.Text);
+ if (!match.Success)
+ return;
+
+ WebUrl = match.Value;
+ OnStartupComplete(WebUrl);
+ }
+
+ var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
+
+ VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
+ }
+
+ private async Task InstallRocmTorch(
+ PyVenvRunner venvRunner,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null
+ )
+ {
+ progress?.Report(
+ new ProgressReport(-1f, "Installing PyTorch for ROCm", isIndeterminate: true)
+ );
+
+ await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
+
+ await venvRunner
+ .PipInstall(
+ new PipInstallArgs()
+ .WithTorch("==2.0.1")
+ .WithTorchVision()
+ .WithTorchExtraIndex("rocm5.1.1"),
+ onConsoleOutput
+ )
+ .ConfigureAwait(false);
+ }
+}
diff --git a/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs b/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs
index 5be7baf0..6468a831 100644
--- a/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs
+++ b/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs
@@ -26,6 +26,10 @@ public class UnknownPackage : BasePackage
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
+ public override string OutputFolderName { get; }
+
+ public override PackageDifficulty InstallerSortOrder { get; }
+
public override Task DownloadPackage(
string installLocation,
DownloadPackageVersionOptions versionOptions,
@@ -39,6 +43,7 @@ public class UnknownPackage : BasePackage
public override Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
@@ -83,6 +88,16 @@ public class UnknownPackage : BasePackage
throw new NotImplementedException();
}
+ public override Task SetupOutputFolderLinks(DirectoryPath installDirectory)
+ {
+ throw new NotImplementedException();
+ }
+
+ public override Task RemoveOutputFolderLinks(DirectoryPath installDirectory)
+ {
+ throw new NotImplementedException();
+ }
+
public override IEnumerable AvailableTorchVersions =>
new[] { TorchVersion.Cuda, TorchVersion.Cpu, TorchVersion.Rocm, TorchVersion.DirectMl };
@@ -108,6 +123,7 @@ public class UnknownPackage : BasePackage
public override Task Update(
InstalledPackage installedPackage,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
bool includePrerelease = false,
Action? onConsoleOutput = null
@@ -122,6 +138,12 @@ public class UnknownPackage : BasePackage
public override List LaunchOptions => new();
+ public override Dictionary>? SharedFolders { get; }
+ public override Dictionary<
+ SharedOutputType,
+ IReadOnlyList
+ >? SharedOutputFolders { get; }
+
public override Task GetLatestVersion() => Task.FromResult(string.Empty);
public override Task GetAllVersionOptions() =>
diff --git a/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs b/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs
index 08fd6fe9..7055fe5d 100644
--- a/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs
+++ b/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs
@@ -4,6 +4,7 @@ using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.RegularExpressions;
using NLog;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.FileInterfaces;
@@ -14,6 +15,7 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class VladAutomatic : BaseGitPackage
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -32,9 +34,10 @@ public class VladAutomatic : BaseGitPackage
public override bool ShouldIgnoreReleases => true;
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Expert;
public override IEnumerable AvailableTorchVersions =>
- new[] { TorchVersion.Cpu, TorchVersion.Rocm, TorchVersion.DirectMl, TorchVersion.Cuda };
+ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm };
public VladAutomatic(
IGithubApiCache githubApi,
@@ -67,6 +70,19 @@ public class VladAutomatic : BaseGitPackage
[SharedFolderType.ControlNet] = new[] { "models/ControlNet" }
};
+ public override Dictionary>? SharedOutputFolders =>
+ new()
+ {
+ [SharedOutputType.Text2Img] = new[] { "outputs/text" },
+ [SharedOutputType.Img2Img] = new[] { "outputs/image" },
+ [SharedOutputType.Extras] = new[] { "outputs/extras" },
+ [SharedOutputType.Img2ImgGrids] = new[] { "outputs/grids" },
+ [SharedOutputType.Text2ImgGrids] = new[] { "outputs/grids" },
+ [SharedOutputType.Saved] = new[] { "outputs/save" },
+ };
+
+ public override string OutputFolderName => "outputs";
+
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List LaunchOptions =>
new()
@@ -161,6 +177,7 @@ public class VladAutomatic : BaseGitPackage
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
@@ -225,6 +242,7 @@ public class VladAutomatic : BaseGitPackage
await PrerequisiteHelper
.RunGit(
installDir.Parent ?? "",
+ null,
"clone",
"https://github.com/vladmandic/automatic",
installDir.Name
@@ -232,7 +250,7 @@ public class VladAutomatic : BaseGitPackage
.ConfigureAwait(false);
await PrerequisiteHelper
- .RunGit(installLocation, "checkout", downloadOptions.CommitHash)
+ .RunGit(installLocation, null, "checkout", downloadOptions.CommitHash)
.ConfigureAwait(false);
}
else if (!string.IsNullOrWhiteSpace(downloadOptions.BranchName))
@@ -240,6 +258,7 @@ public class VladAutomatic : BaseGitPackage
await PrerequisiteHelper
.RunGit(
installDir.Parent ?? "",
+ null,
"clone",
"-b",
downloadOptions.BranchName,
@@ -288,16 +307,12 @@ public class VladAutomatic : BaseGitPackage
public override async Task Update(
InstalledPackage installedPackage,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
bool includePrerelease = false,
Action? onConsoleOutput = null
)
{
- if (installedPackage.Version is null)
- {
- throw new Exception("Version is null");
- }
-
progress?.Report(
new ProgressReport(
-1f,
@@ -308,7 +323,12 @@ public class VladAutomatic : BaseGitPackage
);
await PrerequisiteHelper
- .RunGit(installedPackage.FullPath, "checkout", installedPackage.Version.InstalledBranch)
+ .RunGit(
+ installedPackage.FullPath,
+ onConsoleOutput,
+ "checkout",
+ versionOptions.BranchName
+ )
.ConfigureAwait(false);
var venvRunner = new PyVenvRunner(Path.Combine(installedPackage.FullPath!, "venv"));
@@ -327,7 +347,7 @@ public class VladAutomatic : BaseGitPackage
return new InstalledPackageVersion
{
- InstalledBranch = installedPackage.Version.InstalledBranch,
+ InstalledBranch = versionOptions.BranchName,
InstalledCommitSha = output.Replace(Environment.NewLine, "").Replace("\n", "")
};
}
diff --git a/StabilityMatrix.Core/Models/Packages/VoltaML.cs b/StabilityMatrix.Core/Models/Packages/VoltaML.cs
index b8ff9733..482e9dc9 100644
--- a/StabilityMatrix.Core/Models/Packages/VoltaML.cs
+++ b/StabilityMatrix.Core/Models/Packages/VoltaML.cs
@@ -1,4 +1,5 @@
using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models.Progress;
@@ -8,6 +9,7 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Models.Packages;
+[Singleton(typeof(BasePackage))]
public class VoltaML : BaseGitPackage
{
public override string Name => "voltaML-fast-stable-diffusion";
@@ -24,6 +26,8 @@ public class VoltaML : BaseGitPackage
"https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8"
);
+ public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced;
+
// There are releases but the manager just downloads the latest commit anyways,
// so we'll just limit to commit mode to be more consistent
public override bool ShouldIgnoreReleases => true;
@@ -45,9 +49,20 @@ public class VoltaML : BaseGitPackage
[SharedFolderType.TextualInversion] = new[] { "data/textual-inversion" },
};
+ public override Dictionary>? SharedOutputFolders =>
+ new()
+ {
+ [SharedOutputType.Text2Img] = new[] { "data/outputs/txt2img" },
+ [SharedOutputType.Extras] = new[] { "data/outputs/extra" },
+ [SharedOutputType.Img2Img] = new[] { "data/outputs/img2img" },
+ };
+
+ public override string OutputFolderName => "data/outputs";
+
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
- public override IEnumerable AvailableTorchVersions => new[] { TorchVersion.None };
+ public override IEnumerable AvailableTorchVersions =>
+ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Mps };
public override IEnumerable AvailableSharedFolderMethods =>
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
@@ -136,12 +151,11 @@ public class VoltaML : BaseGitPackage
public override async Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
+ DownloadPackageVersionOptions versionOptions,
IProgress? progress = null,
Action? onConsoleOutput = null
)
{
- await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
-
// Setup venv
progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true));
await using var venvRunner = new PyVenvRunner(Path.Combine(installLocation, "venv"));
diff --git a/StabilityMatrix.Core/Models/Settings/Settings.cs b/StabilityMatrix.Core/Models/Settings/Settings.cs
index cb01f14e..8997f144 100644
--- a/StabilityMatrix.Core/Models/Settings/Settings.cs
+++ b/StabilityMatrix.Core/Models/Settings/Settings.cs
@@ -1,4 +1,5 @@
-using System.Globalization;
+using System.Drawing;
+using System.Globalization;
using System.Text.Json.Serialization;
using Semver;
using StabilityMatrix.Core.Converters.Json;
@@ -70,6 +71,11 @@ public class Settings
///
public bool IsCompletionRemoveUnderscoresEnabled { get; set; } = true;
+ ///
+ /// Format for Inference output image file names
+ ///
+ public string? InferenceOutputImageFileNameFormat { get; set; }
+
///
/// Whether the Inference Image Viewer shows pixel grids at high zoom levels
///
@@ -88,6 +94,9 @@ public class Settings
public HashSet FavoriteModels { get; set; } = new();
+ public Size InferenceImageSize { get; set; } = new(150, 190);
+ public Size OutputsImageSize { get; set; } = new(300, 300);
+
public void RemoveInstalledPackageAndUpdateActive(InstalledPackage package)
{
RemoveInstalledPackageAndUpdateActive(package.Id);
diff --git a/StabilityMatrix.Core/Models/Settings/Size.cs b/StabilityMatrix.Core/Models/Settings/Size.cs
new file mode 100644
index 00000000..a98e29f0
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Settings/Size.cs
@@ -0,0 +1,10 @@
+namespace StabilityMatrix.Core.Models.Settings;
+
+public record struct Size(double Width, double Height)
+{
+ public static Size operator +(Size current, Size other) =>
+ new(current.Width + other.Width, current.Height + other.Height);
+
+ public static Size operator -(Size current, Size other) =>
+ new(current.Width - other.Width, current.Height - other.Height);
+}
diff --git a/StabilityMatrix.Core/Models/SharedOutputType.cs b/StabilityMatrix.Core/Models/SharedOutputType.cs
new file mode 100644
index 00000000..d4dede8a
--- /dev/null
+++ b/StabilityMatrix.Core/Models/SharedOutputType.cs
@@ -0,0 +1,13 @@
+namespace StabilityMatrix.Core.Models;
+
+public enum SharedOutputType
+{
+ All,
+ Text2Img,
+ Img2Img,
+ Extras,
+ Text2ImgGrids,
+ Img2ImgGrids,
+ Saved,
+ Consolidated
+}
diff --git a/StabilityMatrix.Core/Models/TrackedDownload.cs b/StabilityMatrix.Core/Models/TrackedDownload.cs
index 2411b95c..e4017d52 100644
--- a/StabilityMatrix.Core/Models/TrackedDownload.cs
+++ b/StabilityMatrix.Core/Models/TrackedDownload.cs
@@ -13,39 +13,40 @@ namespace StabilityMatrix.Core.Models;
public class TrackedDownload
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
-
+
[JsonIgnore]
private IDownloadService? downloadService;
-
+
[JsonIgnore]
private Task? downloadTask;
-
+
[JsonIgnore]
private CancellationTokenSource? downloadCancellationTokenSource;
-
+
[JsonIgnore]
private CancellationTokenSource? downloadPauseTokenSource;
-
+
[JsonIgnore]
private CancellationTokenSource AggregateCancellationTokenSource =>
CancellationTokenSource.CreateLinkedTokenSource(
downloadCancellationTokenSource?.Token ?? CancellationToken.None,
- downloadPauseTokenSource?.Token ?? CancellationToken.None);
-
+ downloadPauseTokenSource?.Token ?? CancellationToken.None
+ );
+
public required Guid Id { get; init; }
-
+
public required Uri SourceUrl { get; init; }
-
+
public Uri? RedirectedUrl { get; init; }
-
+
public required DirectoryPath DownloadDirectory { get; init; }
-
+
public required string FileName { get; init; }
-
+
public required string TempFileName { get; init; }
-
+
public string? ExpectedHashSha256 { get; set; }
-
+
[JsonIgnore]
[MemberNotNullWhen(true, nameof(ExpectedHashSha256))]
public bool ValidateHash => ExpectedHashSha256 is not null;
@@ -54,22 +55,24 @@ public class TrackedDownload
public ProgressState ProgressState { get; set; } = ProgressState.Inactive;
public List ExtraCleanupFileNames { get; init; } = new();
-
+
// Used for restoring progress on load
public long DownloadedBytes { get; set; }
public long TotalBytes { get; set; }
-
+
///
/// Optional context action to be invoked on completion
///
public IContextAction? ContextAction { get; set; }
-
+
[JsonIgnore]
public Exception? Exception { get; private set; }
-
+
+ private int attempts;
+
#region Events
private WeakEventManager? progressUpdateEventManager;
-
+
public event EventHandler ProgressUpdate
{
add
@@ -79,18 +82,18 @@ public class TrackedDownload
}
remove => progressUpdateEventManager?.RemoveEventHandler(value);
}
-
+
protected void OnProgressUpdate(ProgressReport e)
{
// Update downloaded and total bytes
DownloadedBytes = Convert.ToInt64(e.Current);
TotalBytes = Convert.ToInt64(e.Total);
-
+
progressUpdateEventManager?.RaiseEvent(this, e, nameof(ProgressUpdate));
}
-
+
private WeakEventManager? progressStateChangedEventManager;
-
+
public event EventHandler ProgressStateChanged
{
add
@@ -100,13 +103,13 @@ public class TrackedDownload
}
remove => progressStateChangedEventManager?.RemoveEventHandler(value);
}
-
+
protected void OnProgressStateChanged(ProgressState e)
{
progressStateChangedEventManager?.RaiseEvent(this, e, nameof(ProgressStateChanged));
}
#endregion
-
+
[MemberNotNull(nameof(downloadService))]
private void EnsureDownloadService()
{
@@ -119,42 +122,51 @@ public class TrackedDownload
private async Task StartDownloadTask(long resumeFromByte, CancellationToken cancellationToken)
{
var progress = new Progress(OnProgressUpdate);
-
+
await downloadService!.ResumeDownloadToFileAsync(
SourceUrl.ToString(),
DownloadDirectory.JoinFile(TempFileName),
resumeFromByte,
progress,
- cancellationToken: cancellationToken).ConfigureAwait(false);
-
+ cancellationToken: cancellationToken
+ );
+
// If hash validation is enabled, validate the hash
if (ValidateHash)
{
- OnProgressUpdate(new ProgressReport(0, isIndeterminate: true, type: ProgressType.Hashing));
- var hash = await FileHash.GetSha256Async(DownloadDirectory.JoinFile(TempFileName), progress).ConfigureAwait(false);
+ OnProgressUpdate(
+ new ProgressReport(0, isIndeterminate: true, type: ProgressType.Hashing)
+ );
+ var hash = await FileHash
+ .GetSha256Async(DownloadDirectory.JoinFile(TempFileName), progress)
+ .ConfigureAwait(false);
if (hash != ExpectedHashSha256?.ToLowerInvariant())
{
- throw new Exception($"Hash validation for {FileName} failed, expected {ExpectedHashSha256} but got {hash}");
+ throw new Exception(
+ $"Hash validation for {FileName} failed, expected {ExpectedHashSha256} but got {hash}"
+ );
}
}
}
-
+
public void Start()
{
if (ProgressState != ProgressState.Inactive)
{
- throw new InvalidOperationException($"Download state must be inactive to start, not {ProgressState}");
+ throw new InvalidOperationException(
+ $"Download state must be inactive to start, not {ProgressState}"
+ );
}
Logger.Debug("Starting download {Download}", FileName);
-
+
EnsureDownloadService();
-
+
downloadCancellationTokenSource = new CancellationTokenSource();
downloadPauseTokenSource = new CancellationTokenSource();
-
+
downloadTask = StartDownloadTask(0, AggregateCancellationTokenSource.Token)
.ContinueWith(OnDownloadTaskCompleted);
-
+
ProgressState = ProgressState.Working;
OnProgressStateChanged(ProgressState);
}
@@ -163,27 +175,31 @@ public class TrackedDownload
{
if (ProgressState != ProgressState.Inactive)
{
- Logger.Warn("Attempted to resume download {Download} but it is not paused ({State})", FileName, ProgressState);
+ Logger.Warn(
+ "Attempted to resume download {Download} but it is not paused ({State})",
+ FileName,
+ ProgressState
+ );
}
Logger.Debug("Resuming download {Download}", FileName);
-
+
// Read the temp file to get the current size
var tempSize = 0L;
-
+
var tempFile = DownloadDirectory.JoinFile(TempFileName);
if (tempFile.Exists)
{
tempSize = tempFile.Info.Length;
}
-
+
EnsureDownloadService();
-
+
downloadCancellationTokenSource = new CancellationTokenSource();
downloadPauseTokenSource = new CancellationTokenSource();
-
+
downloadTask = StartDownloadTask(tempSize, AggregateCancellationTokenSource.Token)
.ContinueWith(OnDownloadTaskCompleted);
-
+
ProgressState = ProgressState.Working;
OnProgressStateChanged(ProgressState);
}
@@ -192,22 +208,30 @@ public class TrackedDownload
{
if (ProgressState != ProgressState.Working)
{
- Logger.Warn("Attempted to pause download {Download} but it is not in progress ({State})", FileName, ProgressState);
+ Logger.Warn(
+ "Attempted to pause download {Download} but it is not in progress ({State})",
+ FileName,
+ ProgressState
+ );
return;
}
-
+
Logger.Debug("Pausing download {Download}", FileName);
downloadPauseTokenSource?.Cancel();
}
-
+
public void Cancel()
{
if (ProgressState is not (ProgressState.Working or ProgressState.Inactive))
{
- Logger.Warn("Attempted to cancel download {Download} but it is not in progress ({State})", FileName, ProgressState);
+ Logger.Warn(
+ "Attempted to cancel download {Download} but it is not in progress ({State})",
+ FileName,
+ ProgressState
+ );
return;
}
-
+
Logger.Debug("Cancelling download {Download}", FileName);
// Cancel token if it exists
@@ -219,7 +243,7 @@ public class TrackedDownload
else
{
DoCleanup();
-
+
ProgressState = ProgressState.Cancelled;
OnProgressStateChanged(ProgressState);
}
@@ -238,7 +262,7 @@ public class TrackedDownload
{
Logger.Warn("Failed to delete temp file {TempFile}", TempFileName);
}
-
+
foreach (var extraFile in ExtraCleanupFileNames)
{
try
@@ -251,7 +275,7 @@ public class TrackedDownload
}
}
}
-
+
///
/// Invoked by the task's completion callback
///
@@ -272,7 +296,9 @@ public class TrackedDownload
}
else
{
- throw new InvalidOperationException("Download task was cancelled but neither cancellation token was cancelled.");
+ throw new InvalidOperationException(
+ "Download task was cancelled but neither cancellation token was cancelled."
+ );
}
}
// For faulted
@@ -281,6 +307,23 @@ public class TrackedDownload
// Set the exception
Exception = task.Exception;
+ if (
+ (Exception is IOException || Exception?.InnerException is IOException)
+ && attempts < 3
+ )
+ {
+ attempts++;
+ Logger.Warn(
+ "Download {Download} failed with {Exception}, retrying ({Attempt})",
+ FileName,
+ Exception,
+ attempts
+ );
+ ProgressState = ProgressState.Inactive;
+ Resume();
+ return;
+ }
+
ProgressState = ProgressState.Failed;
}
// Otherwise success
@@ -293,23 +336,23 @@ public class TrackedDownload
if (ProgressState is ProgressState.Failed or ProgressState.Cancelled)
{
DoCleanup();
- }
+ }
else if (ProgressState == ProgressState.Success)
{
// Move the temp file to the final file
DownloadDirectory.JoinFile(TempFileName).MoveTo(DownloadDirectory.JoinFile(FileName));
}
-
+
// For pause, just do nothing
-
+
OnProgressStateChanged(ProgressState);
-
+
// Dispose of the task and cancellation token
downloadTask = null;
downloadCancellationTokenSource = null;
downloadPauseTokenSource = null;
}
-
+
public void SetDownloadService(IDownloadService service)
{
downloadService = service;
diff --git a/StabilityMatrix.Core/Processes/Argument.cs b/StabilityMatrix.Core/Processes/Argument.cs
new file mode 100644
index 00000000..44f2c846
--- /dev/null
+++ b/StabilityMatrix.Core/Processes/Argument.cs
@@ -0,0 +1,6 @@
+using OneOf;
+
+namespace StabilityMatrix.Core.Processes;
+
+[GenerateOneOf]
+public partial class Argument : OneOfBase { }
diff --git a/StabilityMatrix.Core/Processes/ProcessArgs.cs b/StabilityMatrix.Core/Processes/ProcessArgs.cs
new file mode 100644
index 00000000..8fd296b3
--- /dev/null
+++ b/StabilityMatrix.Core/Processes/ProcessArgs.cs
@@ -0,0 +1,73 @@
+using System.Collections;
+using System.Text.RegularExpressions;
+using OneOf;
+
+namespace StabilityMatrix.Core.Processes;
+
+///
+/// Parameter type for command line arguments
+/// Implicitly converts between string and string[],
+/// with no parsing if the input and output types are the same.
+///
+public partial class ProcessArgs : OneOfBase, IEnumerable
+{
+ ///
+ public ProcessArgs(OneOf input)
+ : base(input) { }
+
+ ///
+ /// Whether the argument string contains the given substring,
+ /// or any of the given arguments if the input is an array.
+ ///
+ public bool Contains(string arg) =>
+ Match(str => str.Contains(arg), arr => arr.Any(x => x.Contains(arg)));
+
+ public ProcessArgs Concat(ProcessArgs other) =>
+ Match(
+ str => new ProcessArgs(string.Join(' ', str, other.ToString())),
+ arr => new ProcessArgs(arr.Concat(other.ToArray()).ToArray())
+ );
+
+ public ProcessArgs Prepend(ProcessArgs other) =>
+ Match(
+ str => new ProcessArgs(string.Join(' ', other.ToString(), str)),
+ arr => new ProcessArgs(other.ToArray().Concat(arr).ToArray())
+ );
+
+ ///
+ public IEnumerator GetEnumerator()
+ {
+ return ToArray().AsEnumerable().GetEnumerator();
+ }
+
+ ///
+ public override string ToString()
+ {
+ return Match(str => str, arr => string.Join(' ', arr.Select(ProcessRunner.Quote)));
+ }
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
+
+ public string[] ToArray() =>
+ Match(
+ str => ArgumentsRegex().Matches(str).Select(x => x.Value.Trim('"')).ToArray(),
+ arr => arr
+ );
+
+ // Implicit conversions
+
+ public static implicit operator ProcessArgs(string input) => new(input);
+
+ public static implicit operator ProcessArgs(string[] input) => new(input);
+
+ public static implicit operator string(ProcessArgs input) => input.ToString();
+
+ public static implicit operator string[](ProcessArgs input) => input.ToArray();
+
+ [GeneratedRegex("""[\"].+?[\"]|[^ ]+""", RegexOptions.IgnoreCase)]
+ private static partial Regex ArgumentsRegex();
+}
diff --git a/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs
new file mode 100644
index 00000000..00a0e812
--- /dev/null
+++ b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs
@@ -0,0 +1,78 @@
+using System.Diagnostics;
+using System.Diagnostics.Contracts;
+using OneOf;
+
+namespace StabilityMatrix.Core.Processes;
+
+///
+/// Builder for .
+///
+public record ProcessArgsBuilder
+{
+ protected ProcessArgsBuilder() { }
+
+ public ProcessArgsBuilder(params Argument[] arguments)
+ {
+ Arguments = arguments.ToList();
+ }
+
+ public List Arguments { get; init; } = new();
+
+ private IEnumerable ToStringArgs()
+ {
+ foreach (var argument in Arguments)
+ {
+ if (argument.IsT0)
+ {
+ yield return argument.AsT0;
+ }
+ else
+ {
+ yield return argument.AsT1.Item1;
+ yield return argument.AsT1.Item2;
+ }
+ }
+ }
+
+ ///
+ public override string ToString()
+ {
+ return ToProcessArgs().ToString();
+ }
+
+ public ProcessArgs ToProcessArgs()
+ {
+ return ToStringArgs().ToArray();
+ }
+
+ public static implicit operator ProcessArgs(ProcessArgsBuilder builder) =>
+ builder.ToProcessArgs();
+}
+
+public static class ProcessArgBuilderExtensions
+{
+ [Pure]
+ public static T AddArg(this T builder, Argument argument)
+ where T : ProcessArgsBuilder
+ {
+ return builder with { Arguments = builder.Arguments.Append(argument).ToList() };
+ }
+
+ [Pure]
+ public static T RemoveArgKey(this T builder, string argumentKey)
+ where T : ProcessArgsBuilder
+ {
+ return builder with
+ {
+ Arguments = builder.Arguments
+ .Where(
+ x =>
+ x.Match(
+ stringArg => stringArg != argumentKey,
+ tupleArg => tupleArg.Item1 != argumentKey
+ )
+ )
+ .ToList()
+ };
+ }
+}
diff --git a/StabilityMatrix.Core/Processes/ProcessRunner.cs b/StabilityMatrix.Core/Processes/ProcessRunner.cs
index 94763cb7..3a20cc60 100644
--- a/StabilityMatrix.Core/Processes/ProcessRunner.cs
+++ b/StabilityMatrix.Core/Processes/ProcessRunner.cs
@@ -163,6 +163,55 @@ public static class ProcessRunner
return output;
}
+ public static async Task GetProcessResultAsync(
+ string fileName,
+ ProcessArgs arguments,
+ string? workingDirectory = null,
+ IReadOnlyDictionary? environmentVariables = null
+ )
+ {
+ Logger.Debug($"Starting process '{fileName}' with arguments '{arguments}'");
+
+ var info = new ProcessStartInfo
+ {
+ FileName = fileName,
+ Arguments = arguments,
+ UseShellExecute = false,
+ RedirectStandardOutput = true,
+ RedirectStandardError = true,
+ CreateNoWindow = true
+ };
+
+ if (environmentVariables != null)
+ {
+ foreach (var (key, value) in environmentVariables)
+ {
+ info.EnvironmentVariables[key] = value;
+ }
+ }
+
+ if (workingDirectory != null)
+ {
+ info.WorkingDirectory = workingDirectory;
+ }
+
+ using var process = new Process();
+ process.StartInfo = info;
+ StartTrackedProcess(process);
+
+ var stdout = await process.StandardOutput.ReadToEndAsync().ConfigureAwait(false);
+ var stderr = await process.StandardError.ReadToEndAsync().ConfigureAwait(false);
+
+ await process.WaitForExitAsync().ConfigureAwait(false);
+
+ return new ProcessResult
+ {
+ ExitCode = process.ExitCode,
+ StandardOutput = stdout,
+ StandardError = stderr
+ };
+ }
+
public static Process StartProcess(
string fileName,
string arguments,
diff --git a/StabilityMatrix.Core/Python/PipIndexResult.cs b/StabilityMatrix.Core/Python/PipIndexResult.cs
new file mode 100644
index 00000000..ea9e62f7
--- /dev/null
+++ b/StabilityMatrix.Core/Python/PipIndexResult.cs
@@ -0,0 +1,32 @@
+using System.Collections.Immutable;
+using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Extensions;
+
+namespace StabilityMatrix.Core.Python;
+
+public partial record PipIndexResult
+{
+ public required IReadOnlyList AvailableVersions { get; init; }
+
+ public static PipIndexResult Parse(string output)
+ {
+ var match = AvailableVersionsRegex().Matches(output);
+
+ var versions = output
+ .SplitLines()
+ .Select(line => AvailableVersionsRegex().Match(line))
+ .First(m => m.Success)
+ .Groups["versions"].Value
+ .Split(
+ new[] { ',' },
+ StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries
+ )
+ .ToImmutableArray();
+
+ return new PipIndexResult { AvailableVersions = versions };
+ }
+
+ // Regex, capture the line starting with "Available versions:"
+ [GeneratedRegex(@"^Available versions:\s*(?.*)$")]
+ private static partial Regex AvailableVersionsRegex();
+}
diff --git a/StabilityMatrix.Core/Python/PipInstallArgs.cs b/StabilityMatrix.Core/Python/PipInstallArgs.cs
new file mode 100644
index 00000000..d16aedf7
--- /dev/null
+++ b/StabilityMatrix.Core/Python/PipInstallArgs.cs
@@ -0,0 +1,31 @@
+using StabilityMatrix.Core.Processes;
+
+namespace StabilityMatrix.Core.Python;
+
+public record PipInstallArgs : ProcessArgsBuilder
+{
+ public PipInstallArgs(params Argument[] arguments)
+ : base(arguments) { }
+
+ public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}");
+
+ public PipInstallArgs WithTorchDirectML(string version = "") =>
+ this.AddArg($"torch-directml{version}");
+
+ public PipInstallArgs WithTorchVision(string version = "") =>
+ this.AddArg($"torchvision{version}");
+
+ public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}");
+
+ public PipInstallArgs WithExtraIndex(string indexUrl) =>
+ this.AddArg(("--extra-index-url", indexUrl));
+
+ public PipInstallArgs WithTorchExtraIndex(string index) =>
+ this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}"));
+
+ ///
+ public override string ToString()
+ {
+ return base.ToString();
+ }
+}
diff --git a/StabilityMatrix.Core/Python/PipPackageInfo.cs b/StabilityMatrix.Core/Python/PipPackageInfo.cs
new file mode 100644
index 00000000..eee1af99
--- /dev/null
+++ b/StabilityMatrix.Core/Python/PipPackageInfo.cs
@@ -0,0 +1,7 @@
+namespace StabilityMatrix.Core.Python;
+
+public readonly record struct PipPackageInfo(
+ string Name,
+ string Version,
+ string? EditableProjectLocation = null
+);
diff --git a/StabilityMatrix.Core/Python/PipShowResult.cs b/StabilityMatrix.Core/Python/PipShowResult.cs
new file mode 100644
index 00000000..4bc132fe
--- /dev/null
+++ b/StabilityMatrix.Core/Python/PipShowResult.cs
@@ -0,0 +1,57 @@
+using StabilityMatrix.Core.Extensions;
+
+namespace StabilityMatrix.Core.Python;
+
+public record PipShowResult
+{
+ public required string Name { get; init; }
+
+ public required string Version { get; init; }
+
+ public string? Summary { get; init; }
+
+ public string? HomePage { get; init; }
+
+ public string? Author { get; init; }
+
+ public string? AuthorEmail { get; init; }
+
+ public string? License { get; init; }
+
+ public string? Location { get; init; }
+
+ public List? Requires { get; init; }
+
+ public List? RequiredBy { get; init; }
+
+ public static PipShowResult Parse(string output)
+ {
+ // Decode each line by splitting on first ":" to key and value
+ var lines = output
+ .SplitLines(StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries)
+ .Select(line => line.Split(new[] { ':' }, 2))
+ .Where(split => split.Length == 2)
+ .Select(split => new KeyValuePair(split[0].Trim(), split[1].Trim()))
+ .ToDictionary(pair => pair.Key, pair => pair.Value);
+
+ return new PipShowResult
+ {
+ Name = lines["Name"],
+ Version = lines["Version"],
+ Summary = lines.GetValueOrDefault("Summary"),
+ HomePage = lines.GetValueOrDefault("Home-page"),
+ Author = lines.GetValueOrDefault("Author"),
+ AuthorEmail = lines.GetValueOrDefault("Author-email"),
+ License = lines.GetValueOrDefault("License"),
+ Location = lines.GetValueOrDefault("Location"),
+ Requires = lines
+ .GetValueOrDefault("Requires")
+ ?.Split(new[] { ',' }, StringSplitOptions.TrimEntries)
+ .ToList(),
+ RequiredBy = lines
+ .GetValueOrDefault("Required-by")
+ ?.Split(new[] { ',' }, StringSplitOptions.TrimEntries)
+ .ToList()
+ };
+ }
+}
diff --git a/StabilityMatrix.Core/Python/PyRunner.cs b/StabilityMatrix.Core/Python/PyRunner.cs
index e3e3dac0..ea7268d6 100644
--- a/StabilityMatrix.Core/Python/PyRunner.cs
+++ b/StabilityMatrix.Core/Python/PyRunner.cs
@@ -1,6 +1,7 @@
using System.Diagnostics.CodeAnalysis;
using NLog;
using Python.Runtime;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
@@ -11,42 +12,57 @@ using StabilityMatrix.Core.Python.Interop;
namespace StabilityMatrix.Core.Python;
[SuppressMessage("ReSharper", "NotAccessedPositionalProperty.Global")]
-public record struct PyVersionInfo(int Major, int Minor, int Micro, string ReleaseLevel, int Serial);
+public record struct PyVersionInfo(
+ int Major,
+ int Minor,
+ int Micro,
+ string ReleaseLevel,
+ int Serial
+);
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+[Singleton(typeof(IPyRunner))]
public class PyRunner : IPyRunner
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
-
+
// Set by ISettingsManager.TryFindLibrary()
public static DirectoryPath HomeDir { get; set; } = string.Empty;
-
+
// This is same for all platforms
public const string PythonDirName = "Python310";
-
- public static string PythonDir => Path.Combine(GlobalConfig.LibraryDir, "Assets", PythonDirName);
+
+ public static string PythonDir =>
+ Path.Combine(GlobalConfig.LibraryDir, "Assets", PythonDirName);
///
/// Path to the Python Linked library relative from the Python directory.
///
- public static string RelativePythonDllPath => Compat.Switch(
- (PlatformKind.Windows, "python310.dll"),
- (PlatformKind.Linux, Path.Combine("lib", "libpython3.10.so")),
- (PlatformKind.MacOS, Path.Combine("lib", "libpython3.10.dylib")));
+ public static string RelativePythonDllPath =>
+ Compat.Switch(
+ (PlatformKind.Windows, "python310.dll"),
+ (PlatformKind.Linux, Path.Combine("lib", "libpython3.10.so")),
+ (PlatformKind.MacOS, Path.Combine("lib", "libpython3.10.dylib"))
+ );
public static string PythonDllPath => Path.Combine(PythonDir, RelativePythonDllPath);
- public static string PythonExePath => Compat.Switch(
- (PlatformKind.Windows, Path.Combine(PythonDir, "python.exe")),
- (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "python3")),
- (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "python3")));
- public static string PipExePath => Compat.Switch(
- (PlatformKind.Windows, Path.Combine(PythonDir, "Scripts", "pip.exe")),
- (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "pip3")),
- (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "pip3")));
-
+ public static string PythonExePath =>
+ Compat.Switch(
+ (PlatformKind.Windows, Path.Combine(PythonDir, "python.exe")),
+ (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "python3")),
+ (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "python3"))
+ );
+ public static string PipExePath =>
+ Compat.Switch(
+ (PlatformKind.Windows, Path.Combine(PythonDir, "Scripts", "pip.exe")),
+ (PlatformKind.Linux, Path.Combine(PythonDir, "bin", "pip3")),
+ (PlatformKind.MacOS, Path.Combine(PythonDir, "bin", "pip3"))
+ );
+
public static string GetPipPath => Path.Combine(PythonDir, "get-pip.pyc");
- public static string VenvPath => Path.Combine(PythonDir, "Scripts", "virtualenv" + Compat.ExeExtension);
+ public static string VenvPath =>
+ Path.Combine(PythonDir, "Scripts", "virtualenv" + Compat.ExeExtension);
public static bool PipInstalled => File.Exists(PipExePath);
public static bool VenvInstalled => File.Exists(VenvPath);
@@ -55,7 +71,7 @@ public class PyRunner : IPyRunner
public PyIOStream? StdOutStream;
public PyIOStream? StdErrStream;
-
+
/// $
/// Initializes the Python runtime using the embedded dll.
/// Can be called with no effect after initialization.
@@ -63,10 +79,11 @@ public class PyRunner : IPyRunner
/// Thrown if Python DLL not found.
public async Task Initialize()
{
- if (PythonEngine.IsInitialized) return;
-
+ if (PythonEngine.IsInitialized)
+ return;
+
Logger.Info("Setting PYTHONHOME={PythonDir}", PythonDir.ToRepr());
-
+
// Append Python path to PATH
var newEnvPath = Compat.GetEnvPathWithExtensions(PythonDir);
Logger.Debug("Setting PATH={NewEnvPath}", newEnvPath.ToRepr());
@@ -78,7 +95,7 @@ public class PyRunner : IPyRunner
{
throw new FileNotFoundException("Python linked library not found", PythonDllPath);
}
-
+
Runtime.PythonDLL = PythonDllPath;
PythonEngine.PythonHome = PythonDir;
PythonEngine.Initialize();
@@ -88,12 +105,14 @@ public class PyRunner : IPyRunner
StdOutStream = new PyIOStream();
StdErrStream = new PyIOStream();
await RunInThreadWithLock(() =>
- {
- var sys = Py.Import("sys") as PyModule ??
- throw new NullReferenceException("sys module not found");
- sys.Set("stdout", StdOutStream);
- sys.Set("stderr", StdErrStream);
- }).ConfigureAwait(false);
+ {
+ var sys =
+ Py.Import("sys") as PyModule
+ ?? throw new NullReferenceException("sys module not found");
+ sys.Set("stdout", StdOutStream);
+ sys.Set("stderr", StdErrStream);
+ })
+ .ConfigureAwait(false);
}
///
@@ -129,19 +148,26 @@ public class PyRunner : IPyRunner
/// Time limit for waiting on PyRunning lock.
/// Cancellation token.
/// cancelToken was canceled, or waitTimeout expired.
- public async Task RunInThreadWithLock(Func func, TimeSpan? waitTimeout = null, CancellationToken cancelToken = default)
+ public async Task RunInThreadWithLock(
+ Func func,
+ TimeSpan? waitTimeout = null,
+ CancellationToken cancelToken = default
+ )
{
// Wait to acquire PyRunning lock
await PyRunning.WaitAsync(cancelToken).ConfigureAwait(false);
try
{
- return await Task.Run(() =>
- {
- using (Py.GIL())
+ return await Task.Run(
+ () =>
{
- return func();
- }
- }, cancelToken);
+ using (Py.GIL())
+ {
+ return func();
+ }
+ },
+ cancelToken
+ );
}
finally
{
@@ -156,19 +182,26 @@ public class PyRunner : IPyRunner
/// Time limit for waiting on PyRunning lock.
/// Cancellation token.
/// cancelToken was canceled, or waitTimeout expired.
- public async Task RunInThreadWithLock(Action action, TimeSpan? waitTimeout = null, CancellationToken cancelToken = default)
+ public async Task RunInThreadWithLock(
+ Action action,
+ TimeSpan? waitTimeout = null,
+ CancellationToken cancelToken = default
+ )
{
// Wait to acquire PyRunning lock
await PyRunning.WaitAsync(cancelToken).ConfigureAwait(false);
try
{
- await Task.Run(() =>
- {
- using (Py.GIL())
+ await Task.Run(
+ () =>
{
- action();
- }
- }, cancelToken);
+ using (Py.GIL())
+ {
+ action();
+ }
+ },
+ cancelToken
+ );
}
finally
{
diff --git a/StabilityMatrix.Core/Python/PyVenvRunner.cs b/StabilityMatrix.Core/Python/PyVenvRunner.cs
index f35b66ef..5b3ae3dc 100644
--- a/StabilityMatrix.Core/Python/PyVenvRunner.cs
+++ b/StabilityMatrix.Core/Python/PyVenvRunner.cs
@@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis;
using System.Text;
+using System.Text.Json;
using System.Text.RegularExpressions;
using NLog;
using Salaros.Configuration;
@@ -9,6 +10,7 @@ using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Processes;
+using Yoh.Text.Json.NamingPolicies;
namespace StabilityMatrix.Core.Python;
@@ -19,20 +21,6 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
- private const string TorchPipInstallArgs = "torch==2.0.1 torchvision";
-
- public const string TorchPipInstallArgsCuda =
- $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/cu118";
- public const string TorchPipInstallArgsCpu = TorchPipInstallArgs;
- public const string TorchPipInstallArgsDirectML = "torch-directml";
-
- public const string TorchPipInstallArgsRocm511 =
- $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.1.1";
- public const string TorchPipInstallArgsRocm542 =
- $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.4.2";
- public const string TorchPipInstallArgsRocmNightly56 =
- $"--pre {TorchPipInstallArgs} --index-url https://download.pytorch.org/whl/nightly/rocm5.6";
-
///
/// Relative path to the site-packages folder from the venv root.
/// This is platform specific.
@@ -211,7 +199,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
/// Run a pip install command. Waits for the process to exit.
/// workingDirectory defaults to RootPath.
///
- public async Task PipInstall(string args, Action? outputDataReceived = null)
+ public async Task PipInstall(ProcessArgs args, Action? outputDataReceived = null)
{
if (!File.Exists(PipPath))
{
@@ -231,7 +219,43 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
});
SetPyvenvCfg(PyRunner.PythonDir);
- RunDetached($"-m pip install {args}", outputAction);
+ RunDetached(args.Prepend("-m pip install"), outputAction);
+ await Process.WaitForExitAsync().ConfigureAwait(false);
+
+ // Check return code
+ if (Process.ExitCode != 0)
+ {
+ throw new ProcessException(
+ $"pip install failed with code {Process.ExitCode}: {output.ToString().ToRepr()}"
+ );
+ }
+ }
+
+ ///
+ /// Run a pip uninstall command. Waits for the process to exit.
+ /// workingDirectory defaults to RootPath.
+ ///
+ public async Task PipUninstall(string args, Action? outputDataReceived = null)
+ {
+ if (!File.Exists(PipPath))
+ {
+ throw new FileNotFoundException("pip not found", PipPath);
+ }
+
+ // Record output for errors
+ var output = new StringBuilder();
+
+ var outputAction = new Action(s =>
+ {
+ Logger.Debug($"Pip output: {s.Text}");
+ // Record to output
+ output.Append(s.Text);
+ // Forward to callback
+ outputDataReceived?.Invoke(s);
+ });
+
+ SetPyvenvCfg(PyRunner.PythonDir);
+ RunDetached($"-m pip uninstall -y {args}", outputAction);
await Process.WaitForExitAsync().ConfigureAwait(false);
// Check return code
@@ -270,6 +294,149 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
await PipInstall(pipArgs, outputDataReceived).ConfigureAwait(false);
}
+ ///
+ /// Run a pip list command, return results as PipPackageInfo objects.
+ ///
+ public async Task> PipList()
+ {
+ if (!File.Exists(PipPath))
+ {
+ throw new FileNotFoundException("pip not found", PipPath);
+ }
+
+ SetPyvenvCfg(PyRunner.PythonDir);
+
+ var result = await ProcessRunner
+ .GetProcessResultAsync(
+ PythonPath,
+ "-m pip list --format=json",
+ WorkingDirectory?.FullPath,
+ EnvironmentVariables
+ )
+ .ConfigureAwait(false);
+
+ // Check return code
+ if (result.ExitCode != 0)
+ {
+ throw new ProcessException(
+ $"pip list failed with code {result.ExitCode}: {result.StandardOutput}, {result.StandardError}"
+ );
+ }
+
+ // Use only first line, since there might be pip update messages later
+ if (
+ result.StandardOutput
+ ?.SplitLines(StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries)
+ .FirstOrDefault()
+ is not { } firstLine
+ )
+ {
+ return new List();
+ }
+
+ return JsonSerializer.Deserialize>(
+ firstLine,
+ new JsonSerializerOptions
+ {
+ PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower
+ }
+ ) ?? new List();
+ }
+
+ ///
+ /// Run a pip show command, return results as PipPackageInfo objects.
+ ///
+ public async Task PipShow(string packageName)
+ {
+ if (!File.Exists(PipPath))
+ {
+ throw new FileNotFoundException("pip not found", PipPath);
+ }
+
+ SetPyvenvCfg(PyRunner.PythonDir);
+
+ var result = await ProcessRunner
+ .GetProcessResultAsync(
+ PythonPath,
+ new[] { "-m", "pip", "show", packageName },
+ WorkingDirectory?.FullPath,
+ EnvironmentVariables
+ )
+ .ConfigureAwait(false);
+
+ // Check return code
+ if (result.ExitCode != 0)
+ {
+ throw new ProcessException(
+ $"pip show failed with code {result.ExitCode}: {result.StandardOutput}, {result.StandardError}"
+ );
+ }
+
+ if (result.StandardOutput!.StartsWith("WARNING: Package(s) not found:"))
+ {
+ return null;
+ }
+
+ return PipShowResult.Parse(result.StandardOutput);
+ }
+
+ ///
+ /// Run a pip index command, return result as PipIndexResult.
+ ///
+ public async Task PipIndex(string packageName, string? indexUrl = null)
+ {
+ if (!File.Exists(PipPath))
+ {
+ throw new FileNotFoundException("pip not found", PipPath);
+ }
+
+ SetPyvenvCfg(PyRunner.PythonDir);
+
+ var args = new ProcessArgsBuilder(
+ "-m",
+ "pip",
+ "index",
+ "versions",
+ packageName,
+ "--no-color",
+ "--disable-pip-version-check"
+ );
+
+ if (indexUrl is not null)
+ {
+ args = args.AddArg(("--index-url", indexUrl));
+ }
+
+ var result = await ProcessRunner
+ .GetProcessResultAsync(
+ PythonPath,
+ args,
+ WorkingDirectory?.FullPath,
+ EnvironmentVariables
+ )
+ .ConfigureAwait(false);
+
+ // Check return code
+ if (result.ExitCode != 0)
+ {
+ throw new ProcessException(
+ $"pip index failed with code {result.ExitCode}: {result.StandardOutput}, {result.StandardError}"
+ );
+ }
+
+ if (
+ string.IsNullOrEmpty(result.StandardOutput)
+ || result.StandardOutput!
+ .SplitLines()
+ .Any(l => l.StartsWith("ERROR: No matching distribution found"))
+ )
+ {
+ return null;
+ }
+
+ return PipIndexResult.Parse(result.StandardOutput);
+ }
+
///
/// Run a custom install command. Waits for the process to exit.
/// workingDirectory defaults to RootPath.
@@ -308,7 +475,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
/// Run a command using the venv Python executable and return the result.
///
/// Arguments to pass to the Python executable.
- public async Task Run(string arguments)
+ public async Task Run(ProcessArgs arguments)
{
// Record output for errors
var output = new StringBuilder();
@@ -340,12 +507,14 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
[MemberNotNull(nameof(Process))]
public void RunDetached(
- string arguments,
+ ProcessArgs args,
Action? outputDataReceived,
Action? onExit = null,
bool unbuffered = true
)
{
+ var arguments = args.ToString();
+
if (!PythonPath.Exists)
{
throw new FileNotFoundException("Venv python not found", PythonPath);
@@ -354,10 +523,10 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
Logger.Info(
"Launching venv process [{PythonPath}] "
- + "in working directory [{WorkingDirectory}] with args {arguments.ToRepr()}",
+ + "in working directory [{WorkingDirectory}] with args {Arguments}",
PythonPath,
WorkingDirectory,
- arguments.ToRepr()
+ arguments
);
var filteredOutput =
@@ -380,7 +549,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
}
// Disable pip caching - uses significant memory for large packages like torch
- env["PIP_NO_CACHE_DIR"] = "true";
+ // env["PIP_NO_CACHE_DIR"] = "true";
// On windows, add portable git to PATH and binary as GIT
if (Compat.IsWindows)
diff --git a/StabilityMatrix.Core/Services/DownloadService.cs b/StabilityMatrix.Core/Services/DownloadService.cs
index 7e0be15a..91dbb573 100644
--- a/StabilityMatrix.Core/Services/DownloadService.cs
+++ b/StabilityMatrix.Core/Services/DownloadService.cs
@@ -1,10 +1,12 @@
using System.Net.Http.Headers;
using Microsoft.Extensions.Logging;
using Polly.Contrib.WaitAndRetry;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Progress;
namespace StabilityMatrix.Core.Services;
+[Singleton(typeof(IDownloadService))]
public class DownloadService : IDownloadService
{
private readonly ILogger logger;
diff --git a/StabilityMatrix.Core/Services/IImageIndexService.cs b/StabilityMatrix.Core/Services/IImageIndexService.cs
index 5962df87..66430063 100644
--- a/StabilityMatrix.Core/Services/IImageIndexService.cs
+++ b/StabilityMatrix.Core/Services/IImageIndexService.cs
@@ -9,11 +9,6 @@ public interface IImageIndexService
{
IndexCollection InferenceImages { get; }
- ///
- /// Gets a list of local images that start with the given path prefix
- ///
- Task> GetLocalImagesByPrefix(string pathPrefix);
-
///
/// Refresh index for all collections
///
@@ -25,9 +20,4 @@ public interface IImageIndexService
/// Refreshes the index of local images in the background
///
void BackgroundRefreshIndex();
-
- ///
- /// Removes a local image from the database
- ///
- Task RemoveImage(LocalImageFile imageFile);
}
diff --git a/StabilityMatrix.Core/Services/ISettingsManager.cs b/StabilityMatrix.Core/Services/ISettingsManager.cs
index 462469ae..4b8a92d3 100644
--- a/StabilityMatrix.Core/Services/ISettingsManager.cs
+++ b/StabilityMatrix.Core/Services/ISettingsManager.cs
@@ -21,6 +21,7 @@ public interface ISettingsManager
List PackageInstallsInProgress { get; set; }
Settings Settings { get; }
+ DirectoryPath ConsolidatedImagesDirectory { get; }
///
/// Event fired when the library directory is changed
@@ -69,7 +70,8 @@ public interface ISettingsManager
T source,
Expression> sourceProperty,
Expression> settingsProperty,
- bool setInitial = false
+ bool setInitial = false,
+ TimeSpan? delay = null
)
where T : INotifyPropertyChanged;
diff --git a/StabilityMatrix.Core/Services/ImageIndexService.cs b/StabilityMatrix.Core/Services/ImageIndexService.cs
index c7ee7739..3b6764b3 100644
--- a/StabilityMatrix.Core/Services/ImageIndexService.cs
+++ b/StabilityMatrix.Core/Services/ImageIndexService.cs
@@ -1,11 +1,10 @@
using System.Collections.Concurrent;
using System.Diagnostics;
-using System.Text.Json;
using AsyncAwaitBestPractices;
using DynamicData;
-using DynamicData.Binding;
using Microsoft.Extensions.Logging;
-using StabilityMatrix.Core.Database;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
@@ -13,49 +12,34 @@ using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Core.Services;
+[Singleton(typeof(IImageIndexService))]
public class ImageIndexService : IImageIndexService
{
private readonly ILogger logger;
- private readonly ILiteDbContext liteDbContext;
private readonly ISettingsManager settingsManager;
///
public IndexCollection InferenceImages { get; }
- public ImageIndexService(
- ILogger logger,
- ILiteDbContext liteDbContext,
- ISettingsManager settingsManager
- )
+ public ImageIndexService(ILogger logger, ISettingsManager settingsManager)
{
this.logger = logger;
- this.liteDbContext = liteDbContext;
this.settingsManager = settingsManager;
InferenceImages = new IndexCollection(
this,
- file => file.RelativePath
+ file => file.AbsolutePath
)
{
- RelativePath = "inference"
+ RelativePath = "Inference"
};
EventManager.Instance.ImageFileAdded += OnImageFileAdded;
}
- ///
- public async Task> GetLocalImagesByPrefix(string pathPrefix)
- {
- return await liteDbContext.LocalImageFiles
- .Query()
- .Where(imageFile => imageFile.RelativePath.StartsWith(pathPrefix))
- .ToArrayAsync()
- .ConfigureAwait(false);
- }
-
- public async Task RefreshIndexForAllCollections()
+ public Task RefreshIndexForAllCollections()
{
- await RefreshIndex(InferenceImages).ConfigureAwait(false);
+ return RefreshIndex(InferenceImages);
}
public async Task RefreshIndex(IndexCollection indexCollection)
@@ -72,16 +56,17 @@ public class ImageIndexService : IImageIndexService
// Start
var stopwatch = Stopwatch.StartNew();
- logger.LogInformation("Refreshing images index at {ImagesDir}...", imagesDir);
+ logger.LogInformation("Refreshing images index at {SearchDir}...", searchDir);
var toAdd = new ConcurrentBag();
await Task.Run(() =>
{
- var files = imagesDir.Info
+ var files = searchDir
.EnumerateFiles("*.*", SearchOption.AllDirectories)
- .Where(info => LocalImageFile.SupportedImageExtensions.Contains(info.Extension))
- .Select(info => new FilePath(info));
+ .Where(
+ file => LocalImageFile.SupportedImageExtensions.Contains(file.Extension)
+ );
Parallel.ForEach(
files,
@@ -95,7 +80,7 @@ public class ImageIndexService : IImageIndexService
var indexElapsed = stopwatch.Elapsed;
- indexCollection.ItemsSource.EditDiff(toAdd, LocalImageFile.Comparer);
+ indexCollection.ItemsSource.EditDiff(toAdd);
// End
stopwatch.Stop();
@@ -120,110 +105,9 @@ public class ImageIndexService : IImageIndexService
}
}
- /*public async Task RefreshIndex(IndexCollection indexCollection)
- {
- var imagesDir = settingsManager.ImagesDirectory;
- if (!imagesDir.Exists)
- {
- return;
- }
-
- // Start
- var stopwatch = Stopwatch.StartNew();
- logger.LogInformation("Refreshing images index...");
-
- using var db = await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
-
- var localImageFiles = db.GetCollection("LocalImageFiles")!;
-
- await localImageFiles.DeleteAllAsync().ConfigureAwait(false);
-
- // Record start of actual indexing
- var indexStart = stopwatch.Elapsed;
-
- var added = 0;
-
- foreach (
- var file in imagesDir.Info
- .EnumerateFiles("*.*", SearchOption.AllDirectories)
- .Where(info => LocalImageFile.SupportedImageExtensions.Contains(info.Extension))
- .Select(info => new FilePath(info))
- )
- {
- var relativePath = Path.GetRelativePath(imagesDir, file);
-
- // Skip if not in sub-path
- if (!string.IsNullOrEmpty(subPath) && !relativePath.StartsWith(subPath))
- {
- continue;
- }
-
- // TODO: Support other types
- const LocalImageFileType imageType =
- LocalImageFileType.Inference | LocalImageFileType.TextToImage;
-
- // Get metadata
- using var reader = new BinaryReader(new FileStream(file.FullPath, FileMode.Open));
- var metadata = ImageMetadata.ReadTextChunk(reader, "parameters-json");
- GenerationParameters? genParams = null;
-
- if (!string.IsNullOrWhiteSpace(metadata))
- {
- genParams = JsonSerializer.Deserialize(metadata);
- }
- else
- {
- metadata = ImageMetadata.ReadTextChunk(reader, "parameters");
- if (!string.IsNullOrWhiteSpace(metadata))
- {
- GenerationParameters.TryParse(metadata, out genParams);
- }
- }
-
- var localImage = new LocalImageFile
- {
- RelativePath = relativePath,
- ImageType = imageType,
- CreatedAt = file.Info.CreationTimeUtc,
- LastModifiedAt = file.Info.LastWriteTimeUtc,
- GenerationParameters = genParams
- };
-
- // Insert into database
- await localImageFiles.InsertAsync(localImage).ConfigureAwait(false);
-
- added++;
- }
- // Record end of actual indexing
- var indexEnd = stopwatch.Elapsed;
-
- await db.CommitAsync().ConfigureAwait(false);
-
- // End
- stopwatch.Stop();
- var indexDuration = indexEnd - indexStart;
- var dbDuration = stopwatch.Elapsed - indexDuration;
-
- logger.LogInformation(
- "Image index updated for {Prefix} with {Entries} files, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)",
- subPath,
- added,
- indexDuration.TotalMilliseconds,
- dbDuration.TotalMilliseconds
- );
- }*/
-
///
public void BackgroundRefreshIndex()
{
RefreshIndexForAllCollections().SafeFireAndForget();
}
-
- ///
- public async Task RemoveImage(LocalImageFile imageFile)
- {
- await liteDbContext.LocalImageFiles
- .DeleteAsync(imageFile.RelativePath)
- .ConfigureAwait(false);
- }
}
diff --git a/StabilityMatrix.Core/Services/ModelIndexService.cs b/StabilityMatrix.Core/Services/ModelIndexService.cs
index a70dff68..dec7eeb9 100644
--- a/StabilityMatrix.Core/Services/ModelIndexService.cs
+++ b/StabilityMatrix.Core/Services/ModelIndexService.cs
@@ -1,6 +1,7 @@
using System.Diagnostics;
using AsyncAwaitBestPractices;
using Microsoft.Extensions.Logging;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
@@ -10,6 +11,7 @@ using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Core.Services;
+[Singleton(typeof(IModelIndexService))]
public class ModelIndexService : IModelIndexService
{
private readonly ILogger logger;
diff --git a/StabilityMatrix.Core/Services/SettingsManager.cs b/StabilityMatrix.Core/Services/SettingsManager.cs
index 30695720..47efd00c 100644
--- a/StabilityMatrix.Core/Services/SettingsManager.cs
+++ b/StabilityMatrix.Core/Services/SettingsManager.cs
@@ -7,6 +7,7 @@ using System.Text.Json.Serialization;
using AsyncAwaitBestPractices;
using NLog;
using Refit;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
@@ -15,6 +16,7 @@ using StabilityMatrix.Core.Python;
namespace StabilityMatrix.Core.Services;
+[Singleton(typeof(ISettingsManager))]
public class SettingsManager : ISettingsManager
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -69,6 +71,7 @@ public class SettingsManager : ISettingsManager
public DirectoryPath ImagesDirectory => new(LibraryDir, "Images");
public DirectoryPath ImagesInferenceDirectory => ImagesDirectory.JoinDir("Inference");
+ public DirectoryPath ConsolidatedImagesDirectory => ImagesDirectory.JoinDir("Consolidated");
public Settings Settings { get; private set; } = new();
@@ -163,7 +166,8 @@ public class SettingsManager : ISettingsManager
T source,
Expression> sourceProperty,
Expression> settingsProperty,
- bool setInitial = false
+ bool setInitial = false,
+ TimeSpan? delay = null
)
where T : INotifyPropertyChanged
{
@@ -215,7 +219,14 @@ public class SettingsManager : ISettingsManager
if (IsLibraryDirSet)
{
- SaveSettingsAsync().SafeFireAndForget();
+ if (delay != null)
+ {
+ SaveSettingsDelayed(delay.Value).SafeFireAndForget();
+ }
+ else
+ {
+ SaveSettingsAsync().SafeFireAndForget();
+ }
}
else
{
@@ -654,4 +665,37 @@ public class SettingsManager : ISettingsManager
{
return Task.Run(SaveSettings);
}
+
+ private CancellationTokenSource? delayedSaveCts;
+
+ private Task SaveSettingsDelayed(TimeSpan delay)
+ {
+ var cts = new CancellationTokenSource();
+
+ var oldCancellationToken = Interlocked.Exchange(ref delayedSaveCts, cts);
+
+ try
+ {
+ oldCancellationToken?.Cancel();
+ }
+ catch (ObjectDisposedException) { }
+
+ return Task.Run(
+ async () =>
+ {
+ try
+ {
+ await Task.Delay(delay, cts.Token);
+
+ await SaveSettingsAsync();
+ }
+ catch (TaskCanceledException) { }
+ finally
+ {
+ cts.Dispose();
+ }
+ },
+ CancellationToken.None
+ );
+ }
}
diff --git a/StabilityMatrix.Core/StabilityMatrix.Core.csproj b/StabilityMatrix.Core/StabilityMatrix.Core.csproj
index 1f95863e..70900a96 100644
--- a/StabilityMatrix.Core/StabilityMatrix.Core.csproj
+++ b/StabilityMatrix.Core/StabilityMatrix.Core.csproj
@@ -15,34 +15,36 @@
-
+
-
+
-
-
-
-
-
+
+
+
+
+
-
-
+
+
-
+
+
+
-
+
-
-
-
-
+
+
+
+
diff --git a/StabilityMatrix.Core/Updater/UpdateHelper.cs b/StabilityMatrix.Core/Updater/UpdateHelper.cs
index 401736c1..325af086 100644
--- a/StabilityMatrix.Core/Updater/UpdateHelper.cs
+++ b/StabilityMatrix.Core/Updater/UpdateHelper.cs
@@ -2,6 +2,7 @@
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
+using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Configs;
@@ -12,6 +13,7 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Updater;
+[Singleton(typeof(IUpdateHelper))]
public class UpdateHelper : IUpdateHelper
{
private readonly ILogger logger;
diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs
new file mode 100644
index 00000000..5905aca0
--- /dev/null
+++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs
@@ -0,0 +1,28 @@
+using System.ComponentModel.DataAnnotations;
+using StabilityMatrix.Avalonia.Models.Inference;
+
+namespace StabilityMatrix.Tests.Avalonia;
+
+[TestClass]
+public class FileNameFormatProviderTests
+{
+ [TestMethod]
+ public void TestFileNameFormatProviderValidate_Valid_ShouldNotThrow()
+ {
+ var provider = new FileNameFormatProvider();
+
+ var result = provider.Validate("{date}_{time}-{model_name}-{seed}");
+ Assert.AreEqual(ValidationResult.Success, result);
+ }
+
+ [TestMethod]
+ public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow()
+ {
+ var provider = new FileNameFormatProvider();
+
+ var result = provider.Validate("{date}_{time}-{model_name}-{seed}-{invalid}");
+ Assert.AreNotEqual(ValidationResult.Success, result);
+
+ Assert.AreEqual("Unknown variable 'invalid'", result.ErrorMessage);
+ }
+}
diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs
new file mode 100644
index 00000000..0da1eb84
--- /dev/null
+++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatTests.cs
@@ -0,0 +1,24 @@
+using StabilityMatrix.Avalonia.Models;
+using StabilityMatrix.Avalonia.Models.Inference;
+using StabilityMatrix.Core.Models;
+
+namespace StabilityMatrix.Tests.Avalonia;
+
+[TestClass]
+public class FileNameFormatTests
+{
+ [TestMethod]
+ public void TestFileNameFormatParse()
+ {
+ var provider = new FileNameFormatProvider
+ {
+ GenerationParameters = new GenerationParameters { Seed = 123 },
+ ProjectName = "uwu",
+ ProjectType = InferenceProjectType.TextToImage,
+ };
+
+ var format = FileNameFormat.Parse("{project_type} - {project_name} ({seed})", provider);
+
+ Assert.AreEqual("TextToImage - uwu (123)", format.GetFileName());
+ }
+}
diff --git a/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs
new file mode 100644
index 00000000..9bb09eab
--- /dev/null
+++ b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs
@@ -0,0 +1,61 @@
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+
+namespace StabilityMatrix.Tests.Core;
+
+[TestClass]
+public class PipInstallArgsTests
+{
+ [TestMethod]
+ public void TestGetTorch()
+ {
+ // Arrange
+ const string version = "==2.1.0";
+
+ // Act
+ var args = new PipInstallArgs().WithTorch(version).ToProcessArgs().ToString();
+
+ // Assert
+ Assert.AreEqual("torch==2.1.0", args);
+ }
+
+ [TestMethod]
+ public void TestGetTorchWithExtraIndex()
+ {
+ // Arrange
+ const string version = ">=2.0.0";
+ const string index = "cu118";
+
+ // Act
+ var args = new PipInstallArgs()
+ .WithTorch(version)
+ .WithTorchVision()
+ .WithTorchExtraIndex(index)
+ .ToProcessArgs()
+ .ToString();
+
+ // Assert
+ Assert.AreEqual(
+ "torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118",
+ args
+ );
+ }
+
+ [TestMethod]
+ public void TestGetTorchWithMoreStuff()
+ {
+ // Act
+ var args = new PipInstallArgs()
+ .AddArg("--pre")
+ .WithTorch("~=2.0.0")
+ .WithTorchVision()
+ .WithTorchExtraIndex("nightly/cpu")
+ .ToString();
+
+ // Assert
+ Assert.AreEqual(
+ "--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu",
+ args
+ );
+ }
+}
diff --git a/StabilityMatrix.Tests/Models/GenerationParametersTests.cs b/StabilityMatrix.Tests/Models/GenerationParametersTests.cs
new file mode 100644
index 00000000..d22caf8c
--- /dev/null
+++ b/StabilityMatrix.Tests/Models/GenerationParametersTests.cs
@@ -0,0 +1,70 @@
+using StabilityMatrix.Core.Models;
+
+namespace StabilityMatrix.Tests.Models;
+
+[TestClass]
+public class GenerationParametersTests
+{
+ [TestMethod]
+ public void TestParse()
+ {
+ const string data = """
+ test123
+ Negative prompt: test, easy negative
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3589107295, Size: 1024x1028, Model hash: 9aa0c3e54d, Model: nightvisionXL_v0770_BakedVAE, VAE hash: 235745af8d, VAE: sdxl_vae.safetensors, Style Selector Enabled: True, Style Selector Randomize: False, Style Selector Style: base, Version: 1.6.0
+ """;
+
+ Assert.IsTrue(GenerationParameters.TryParse(data, out var result));
+
+ Assert.AreEqual("test123", result.PositivePrompt);
+ Assert.AreEqual("test, easy negative", result.NegativePrompt);
+ Assert.AreEqual(20, result.Steps);
+ Assert.AreEqual("Euler a", result.Sampler);
+ Assert.AreEqual(7, result.CfgScale);
+ Assert.AreEqual(3589107295, result.Seed);
+ Assert.AreEqual(1024, result.Width);
+ Assert.AreEqual(1028, result.Height);
+ Assert.AreEqual("9aa0c3e54d", result.ModelHash);
+ Assert.AreEqual("nightvisionXL_v0770_BakedVAE", result.ModelName);
+ }
+
+ [TestMethod]
+ public void TestParse_NoNegative()
+ {
+ const string data = """
+ test123
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3589107295, Size: 1024x1028, Model hash: 9aa0c3e54d, Model: nightvisionXL_v0770_BakedVAE, VAE hash: 235745af8d, VAE: sdxl_vae.safetensors, Style Selector Enabled: True, Style Selector Randomize: False, Style Selector Style: base, Version: 1.6.0
+ """;
+
+ Assert.IsTrue(GenerationParameters.TryParse(data, out var result));
+
+ Assert.AreEqual("test123", result.PositivePrompt);
+ Assert.IsNull(result.NegativePrompt);
+ Assert.AreEqual(20, result.Steps);
+ Assert.AreEqual("Euler a", result.Sampler);
+ Assert.AreEqual(7, result.CfgScale);
+ Assert.AreEqual(3589107295, result.Seed);
+ Assert.AreEqual(1024, result.Width);
+ Assert.AreEqual(1028, result.Height);
+ Assert.AreEqual("9aa0c3e54d", result.ModelHash);
+ Assert.AreEqual("nightvisionXL_v0770_BakedVAE", result.ModelName);
+ }
+
+ [TestMethod]
+ public void TestParseLineFields()
+ {
+ const string lastLine =
+ @"Steps: 30, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 2216407431, Size: 640x896, Model hash: eb2h052f91, Model: anime_v1";
+
+ var fields = GenerationParameters.ParseLine(lastLine);
+
+ Assert.AreEqual(7, fields.Count);
+ Assert.AreEqual("30", fields["Steps"]);
+ Assert.AreEqual("DPM++ 2M Karras", fields["Sampler"]);
+ Assert.AreEqual("7", fields["CFG scale"]);
+ Assert.AreEqual("2216407431", fields["Seed"]);
+ Assert.AreEqual("640x896", fields["Size"]);
+ Assert.AreEqual("eb2h052f91", fields["Model hash"]);
+ Assert.AreEqual("anime_v1", fields["Model"]);
+ }
+}
diff --git a/StabilityMatrix.Tests/Models/ProcessArgsTests.cs b/StabilityMatrix.Tests/Models/ProcessArgsTests.cs
new file mode 100644
index 00000000..ae3281d9
--- /dev/null
+++ b/StabilityMatrix.Tests/Models/ProcessArgsTests.cs
@@ -0,0 +1,43 @@
+using StabilityMatrix.Core.Processes;
+
+namespace StabilityMatrix.Tests.Models;
+
+[TestClass]
+public class ProcessArgsTests
+{
+ [DataTestMethod]
+ [DataRow("pip", new[] { "pip" })]
+ [DataRow("pip install torch", new[] { "pip", "install", "torch" })]
+ [DataRow(
+ "pip install -r \"file spaces/here\"",
+ new[] { "pip", "install", "-r", "file spaces/here" }
+ )]
+ [DataRow(
+ "pip install -r \"file spaces\\here\"",
+ new[] { "pip", "install", "-r", "file spaces\\here" }
+ )]
+ public void TestStringToArray(string input, string[] expected)
+ {
+ ProcessArgs args = input;
+ string[] result = args;
+ CollectionAssert.AreEqual(expected, result);
+ }
+
+ [DataTestMethod]
+ [DataRow(new[] { "pip" }, "pip")]
+ [DataRow(new[] { "pip", "install", "torch" }, "pip install torch")]
+ [DataRow(
+ new[] { "pip", "install", "-r", "file spaces/here" },
+ "pip install -r \"file spaces/here\""
+ )]
+ [DataRow(
+ new[] { "pip", "install", "-r", "file spaces\\here" },
+ "pip install -r \"file spaces\\here\""
+ )]
+ public void TestArrayToString(string[] input, string expected)
+ {
+ ProcessArgs args = input;
+ string result = args;
+ Assert.AreEqual(expected, result);
+ }
+}
diff --git a/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj b/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj
index e82a92a8..f604bbea 100644
--- a/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj
+++ b/StabilityMatrix.Tests/StabilityMatrix.Tests.csproj
@@ -11,18 +11,18 @@
-
+
-
+
-
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
-
+