diff --git a/CHANGELOG.md b/CHANGELOG.md
index 764ee032..c593229e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,11 +5,37 @@ All notable changes to Stability Matrix will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2.0.0.html).
+## v2.3.0
+### Added
+- New installable Package - [Fooocus](https://github.com/lllyasviel/Fooocus)
+- Added "Select New Data Directory" button to Settings
+- Added "Skip to First/Last Page" buttons to the Model Browser
+- Added VAE as a checkpoint category in the Model Browser
+- Pause/Resume/Cancel buttons on downloads popup. Paused downloads persists and may be resumed after restarting the app
+- Unknown Package installs in the Package directory will now show up with a button to import them
+### Fixed
+- Fixed issue where model version wouldn't be selected in the "All Versions" section of the Model Browser
+- Improved Checkpoints page indexing performance
+- Fixed issue where Checkpoints page may not show all checkpoints after clearing search filter
+- Fixed issue where Checkpoints page may show incorrect checkpoints for the given filter after changing pages
+- Fixed issue where Open Web UI button would try to load 0.0.0.0 addresses
+- Fixed Dictionary error when launch arguments saved with duplicate arguments
+- Fixed Launch arguments search not working
+### Changed
+- Changed update method for SD.Next to use the built-in upgrade functionality
+- Model Browser navigation buttons are no longer disabled while changing pages
+
+## v2.2.1
+### Fixed
+- Fixed SD.Next shared folders config not working with new config format, reverted to Junctions / Symlinks
+
## v2.2.1
+
### Fixed
- Fixed SD.Next shared folders config not working with new config format, reverted to Junctions / Symlinks
## v2.2.0
+
### Added
- Added option to search by Base Model in the Model Browser
- Animated page transitions
diff --git a/README.md b/README.md
index 5c2c9897..94eeac30 100644
--- a/README.md
+++ b/README.md
@@ -13,23 +13,27 @@
[sdnext]: https://github.com/vladmandic/automatic
[voltaml]: https://github.com/VoltaML/voltaML-fast-stable-diffusion
[invokeai]: https://github.com/invoke-ai/InvokeAI
+[fooocus]: https://github.com/lllyasviel/Fooocus
[civitai]: https://civitai.com/
Multi-Platform Package Manager for Stable Diffusion
### 🖱️ One click install and update for Stable Diffusion Web UI Packages
-- Supports [Automatic 1111][auto1111], [Comfy UI][comfy], [SD.Next (Vladmandic)][sdnext], [VoltaML][voltaml], [InvokeAI][invokeai]
+- Supports [Automatic 1111][auto1111], [Comfy UI][comfy], [SD.Next (Vladmandic)][sdnext], [VoltaML][voltaml], [InvokeAI][invokeai], and [Fooocus][fooocus]
- Embedded Git and Python dependencies, with no need for either to be globally installed
-- Fully Portable, move Stability Matrix's Data Directory to a new drive or computer at any time
+- Fully portable - move Stability Matrix's Data Directory to a new drive or computer at any time
+
### 🚀 Launcher with syntax highlighted terminal emulator, routed GUI input prompts
- Launch arguments editor with predefined or custom options for each Package install
-- Package environment variables
+- Configurable Environment Variables
+
### 🗃️ Checkpoint Manager, configured to be shared by all Package installs
- Option to find CivitAI metadata and preview thumbnails for new local imports
+
### ☁️ Model Browser to import from [CivitAI][civitai]
- Automatically imports to the associated model folder depending on the model type
-- Also downloads relavent metadata files and preview image
+- Downloads relevant metadata files and preview image
![header](https://github.com/LykosAI/StabilityMatrix/assets/13956642/a9c5f925-8561-49ba-855b-1b7bf57d7c0d)
@@ -48,17 +52,17 @@ Multi-Platform Package Manager for Stable Diffusion
### Model browser powered by [Civit AI][civitai]
- Downloads new models, automatically uses the appropriate shared model directory
-- Available immediately to all installed packages
+- Pause and resume downloads, even after closing the app
### Shared model directory for all your packages
-
- Import local models by simple drag and drop
+- Option to find CivitAI metadata and preview thumbnails for new local imports
- Toggle visibility of categories like LoRA, VAE, CLIP, etc.
-- For models imported from Civit AI, shows additional information like version, fp precision, and preview thumbnail on hover
+
diff --git a/StabilityMatrix.Avalonia/App.axaml b/StabilityMatrix.Avalonia/App.axaml
index 7d74ab58..3977db27 100644
--- a/StabilityMatrix.Avalonia/App.axaml
+++ b/StabilityMatrix.Avalonia/App.axaml
@@ -28,5 +28,6 @@
+
diff --git a/StabilityMatrix.Avalonia/App.axaml.cs b/StabilityMatrix.Avalonia/App.axaml.cs
index 382f7274..78baebba 100644
--- a/StabilityMatrix.Avalonia/App.axaml.cs
+++ b/StabilityMatrix.Avalonia/App.axaml.cs
@@ -34,11 +34,13 @@ using Sentry;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.DesignData;
using StabilityMatrix.Avalonia.Helpers;
+using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
+using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.ViewModels.PackageManager;
using StabilityMatrix.Avalonia.Views;
@@ -57,8 +59,6 @@ using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
using StabilityMatrix.Core.Updater;
using Application = Avalonia.Application;
-using CheckpointFile = StabilityMatrix.Avalonia.ViewModels.CheckpointManager.CheckpointFile;
-using CheckpointFolder = StabilityMatrix.Avalonia.ViewModels.CheckpointManager.CheckpointFolder;
using LogLevel = Microsoft.Extensions.Logging.LogLevel;
namespace StabilityMatrix.Avalonia;
@@ -194,7 +194,12 @@ public sealed class App : Application
Services = services.BuildServiceProvider();
var settingsManager = Services.GetRequiredService();
- settingsManager.TryFindLibrary();
+
+ if (settingsManager.TryFindLibrary())
+ {
+ Cultures.TrySetSupportedCulture(settingsManager.Settings.Language);
+ }
+
Services.GetRequiredService().StartEventListener();
}
@@ -204,13 +209,15 @@ public sealed class App : Application
.AddSingleton()
.AddSingleton()
.AddSingleton()
+ .AddSingleton()
.AddSingleton()
.AddSingleton();
services.AddSingleton(provider =>
new MainWindowViewModel(provider.GetRequiredService(),
provider.GetRequiredService(),
- provider.GetRequiredService>())
+ provider.GetRequiredService>(),
+ provider.GetRequiredService())
{
Pages =
{
@@ -240,14 +247,16 @@ public sealed class App : Application
services.AddTransient();
services.AddTransient();
services.AddTransient();
+ services.AddTransient();
+
+ // Dialog view models (singleton)
services.AddSingleton();
services.AddSingleton();
// Other transients (usually sub view models)
- services.AddTransient()
- .AddTransient()
- .AddTransient();
-
+ services.AddTransient();
+ services.AddTransient();
+ services.AddTransient();
services.AddTransient();
// Global progress
@@ -273,7 +282,9 @@ public sealed class App : Application
.Register(provider.GetRequiredService)
.Register(provider.GetRequiredService)
.Register(provider.GetRequiredService)
- .Register(provider.GetRequiredService));
+ .Register(provider.GetRequiredService)
+ .Register(provider.GetRequiredService)
+ );
}
internal static void ConfigureViews(IServiceCollection services)
@@ -285,6 +296,7 @@ public sealed class App : Application
services.AddSingleton();
services.AddSingleton();
services.AddSingleton();
+ services.AddSingleton();
// Dialogs
services.AddTransient();
@@ -292,6 +304,7 @@ public sealed class App : Application
services.AddTransient();
services.AddTransient();
services.AddTransient();
+ services.AddTransient();
// Controls
services.AddTransient();
@@ -308,6 +321,7 @@ public sealed class App : Application
services.AddSingleton();
services.AddSingleton();
services.AddSingleton();
+ services.AddSingleton();
}
private static IServiceCollection ConfigureServices()
@@ -333,6 +347,10 @@ public sealed class App : Application
services.AddSingleton();
services.AddSingleton();
+ services.AddSingleton();
+ services.AddSingleton(provider =>
+ (IDisposable) provider.GetRequiredService());
+
// Rich presence
services.AddSingleton();
services.AddSingleton(provider =>
diff --git a/StabilityMatrix.Avalonia/Controls/AutoGrid.cs b/StabilityMatrix.Avalonia/Controls/AutoGrid.cs
new file mode 100644
index 00000000..df5c4735
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Controls/AutoGrid.cs
@@ -0,0 +1,408 @@
+// Modified from https://github.com/AvaloniaUI/AvaloniaAutoGrid
+/*The MIT License (MIT)
+
+Copyright (c) 2013 Charles Brown (carbonrobot)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.*/
+
+using System;
+using System.ComponentModel;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using Avalonia;
+using Avalonia.Controls;
+using Avalonia.Data;
+using Avalonia.Layout;
+
+namespace StabilityMatrix.Avalonia.Controls;
+
+///
+/// Defines a flexible grid area that consists of columns and rows.
+/// Depending on the orientation, either the rows or the columns are auto-generated,
+/// and the children's position is set according to their index.
+///
+[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+public class AutoGrid : Grid
+{
+ ///
+ /// Gets or sets the child horizontal alignment.
+ ///
+ /// The child horizontal alignment.
+ [Category("Layout"), Description("Presets the horizontal alignment of all child controls")]
+ public HorizontalAlignment? ChildHorizontalAlignment
+ {
+ get => (HorizontalAlignment?)GetValue(ChildHorizontalAlignmentProperty);
+ set => SetValue(ChildHorizontalAlignmentProperty, value);
+ }
+
+ ///
+ /// Gets or sets the child margin.
+ ///
+ /// The child margin.
+ [Category("Layout"), Description("Presets the margin of all child controls")]
+ public Thickness? ChildMargin
+ {
+ get => (Thickness?)GetValue(ChildMarginProperty);
+ set => SetValue(ChildMarginProperty, value);
+ }
+
+ ///
+ /// Gets or sets the child vertical alignment.
+ ///
+ /// The child vertical alignment.
+ [Category("Layout"), Description("Presets the vertical alignment of all child controls")]
+ public VerticalAlignment? ChildVerticalAlignment
+ {
+ get => (VerticalAlignment?)GetValue(ChildVerticalAlignmentProperty);
+ set => SetValue(ChildVerticalAlignmentProperty, value);
+ }
+
+ ///
+ /// Gets or sets the column count
+ ///
+ [Category("Layout"), Description("Defines a set number of columns")]
+ public int ColumnCount
+ {
+ get => (int)GetValue(ColumnCountProperty)!;
+ set => SetValue(ColumnCountProperty, value);
+ }
+
+ ///
+ /// Gets or sets the fixed column width
+ ///
+ [Category("Layout"), Description("Presets the width of all columns set using the ColumnCount property")]
+
+ public GridLength ColumnWidth
+ {
+ get => (GridLength)GetValue(ColumnWidthProperty)!;
+ set => SetValue(ColumnWidthProperty, value);
+ }
+
+ ///
+ /// Gets or sets a value indicating whether the children are automatically indexed.
+ ///
+ /// The default is true.
+ /// Note that if children are already indexed, setting this property to false will not remove their indices.
+ ///
+ ///
+ [Category("Layout"), Description("Set to false to disable the auto layout functionality")]
+ public bool IsAutoIndexing
+ {
+ get => (bool)GetValue(IsAutoIndexingProperty)!;
+ set => SetValue(IsAutoIndexingProperty, value);
+ }
+
+ ///
+ /// Gets or sets the orientation.
+ /// The default is Vertical.
+ ///
+ /// The orientation.
+ [Category("Layout"), Description("Defines the directionality of the autolayout. Use vertical for a column first layout, horizontal for a row first layout.")]
+ public Orientation Orientation
+ {
+ get => (Orientation)GetValue(OrientationProperty)!;
+ set => SetValue(OrientationProperty, value);
+ }
+
+ ///
+ /// Gets or sets the number of rows
+ ///
+ [Category("Layout"), Description("Defines a set number of rows")]
+ public int RowCount
+ {
+ get => (int)GetValue(RowCountProperty)!;
+ set => SetValue(RowCountProperty, value);
+ }
+
+ ///
+ /// Gets or sets the fixed row height
+ ///
+ [Category("Layout"), Description("Presets the height of all rows set using the RowCount property")]
+ public GridLength RowHeight
+ {
+ get => (GridLength)GetValue(RowHeightProperty)!;
+ set => SetValue(RowHeightProperty, value);
+ }
+
+ ///
+ /// Handles the column count changed event
+ ///
+ public static void ColumnCountChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ if ((int)e.NewValue! < 0)
+ return;
+
+ var grid = (AutoGrid)e.Sender;
+
+
+ // look for an existing column definition for the height
+ var width = grid.ColumnWidth;
+ if (!grid.IsSet(ColumnWidthProperty) && grid.ColumnDefinitions.Count > 0)
+ width = grid.ColumnDefinitions[0].Width;
+
+ // clear and rebuild
+ grid.ColumnDefinitions.Clear();
+ for (var i = 0; i < (int)e.NewValue; i++)
+ grid.ColumnDefinitions.Add(
+ new ColumnDefinition() { Width = width });
+ }
+
+ ///
+ /// Handle the fixed column width changed event
+ ///
+ public static void FixedColumnWidthChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ var grid = (AutoGrid)e.Sender;
+
+ // add a default column if missing
+ if (grid.ColumnDefinitions.Count == 0)
+ grid.ColumnDefinitions.Add(new ColumnDefinition());
+
+ // set all existing columns to this width
+ foreach (var t in grid.ColumnDefinitions)
+ t.Width = (GridLength)e.NewValue!;
+ }
+
+ ///
+ /// Handle the fixed row height changed event
+ ///
+ public static void FixedRowHeightChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ var grid = (AutoGrid)e.Sender;
+
+ // add a default row if missing
+ if (grid.RowDefinitions.Count == 0)
+ grid.RowDefinitions.Add(new RowDefinition());
+
+ // set all existing rows to this height
+ foreach (var t in grid.RowDefinitions)
+ t.Height = (GridLength)e.NewValue!;
+ }
+
+ ///
+ /// Handles the row count changed event
+ ///
+ public static void RowCountChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ if ((int)e.NewValue! < 0)
+ return;
+
+ var grid = (AutoGrid)e.Sender;
+
+ // look for an existing row to get the height
+ var height = grid.RowHeight;
+ if (!grid.IsSet(RowHeightProperty) && grid.RowDefinitions.Count > 0)
+ height = grid.RowDefinitions[0].Height;
+
+ // clear and rebuild
+ grid.RowDefinitions.Clear();
+ for (var i = 0; i < (int)e.NewValue; i++)
+ grid.RowDefinitions.Add(
+ new RowDefinition() { Height = height });
+ }
+
+ ///
+ /// Called when [child horizontal alignment changed].
+ ///
+ private static void OnChildHorizontalAlignmentChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ var grid = (AutoGrid)e.Sender;
+ foreach (var child in grid.Children)
+ {
+ child.SetValue(HorizontalAlignmentProperty,
+ grid.ChildHorizontalAlignment ?? AvaloniaProperty.UnsetValue);
+ }
+ }
+
+ ///
+ /// Called when [child layout changed].
+ ///
+ private static void OnChildMarginChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ var grid = (AutoGrid)e.Sender;
+ foreach (var child in grid.Children)
+ {
+ child.SetValue(MarginProperty, grid.ChildMargin ?? AvaloniaProperty.UnsetValue);
+ }
+ }
+
+ ///
+ /// Called when [child vertical alignment changed].
+ ///
+ private static void OnChildVerticalAlignmentChanged(AvaloniaPropertyChangedEventArgs e)
+ {
+ var grid = (AutoGrid)e.Sender;
+ foreach (var child in grid.Children)
+ {
+ child.SetValue(VerticalAlignmentProperty, grid.ChildVerticalAlignment ?? AvaloniaProperty.UnsetValue);
+ }
+ }
+
+ ///
+ /// Apply child margins and layout effects such as alignment
+ ///
+ private void ApplyChildLayout(Control child)
+ {
+ if (ChildMargin != null)
+ {
+ child.SetValue(MarginProperty, ChildMargin.Value, BindingPriority.Template);
+ }
+ if (ChildHorizontalAlignment != null)
+ {
+ child.SetValue(HorizontalAlignmentProperty, ChildHorizontalAlignment.Value, BindingPriority.Template);
+ }
+ if (ChildVerticalAlignment != null)
+ {
+ child.SetValue(VerticalAlignmentProperty, ChildVerticalAlignment.Value, BindingPriority.Template);
+ }
+ }
+
+ ///
+ /// Clamp a value to its maximum.
+ ///
+ private int Clamp(int value, int max)
+ {
+ return (value > max) ? max : value;
+ }
+
+ ///
+ /// Perform the grid layout of row and column indexes
+ ///
+ private void PerformLayout()
+ {
+ var fillRowFirst = Orientation == Orientation.Horizontal;
+ var rowCount = RowDefinitions.Count;
+ var colCount = ColumnDefinitions.Count;
+
+ if (rowCount == 0 || colCount == 0)
+ return;
+
+ var position = 0;
+ var skip = new bool[rowCount, colCount];
+ foreach (var child in Children.OfType())
+ {
+ var childIsCollapsed = !child.IsVisible;
+ if (IsAutoIndexing && !childIsCollapsed)
+ {
+ if (fillRowFirst)
+ {
+ var row = Clamp(position / colCount, rowCount - 1);
+ var col = Clamp(position % colCount, colCount - 1);
+ if (skip[row, col])
+ {
+ position++;
+ row = (position / colCount);
+ col = (position % colCount);
+ }
+
+ SetRow(child, row);
+ SetColumn(child, col);
+ position += GetColumnSpan(child);
+
+ var offset = GetRowSpan(child) - 1;
+ while (offset > 0)
+ {
+ skip[row + offset--, col] = true;
+ }
+ }
+ else
+ {
+ var row = Clamp(position % rowCount, rowCount - 1);
+ var col = Clamp(position / rowCount, colCount - 1);
+ if (skip[row, col])
+ {
+ position++;
+ row = position % rowCount;
+ col = position / rowCount;
+ }
+
+ SetRow(child, row);
+ SetColumn(child, col);
+ position += GetRowSpan(child);
+
+ var offset = GetColumnSpan(child) - 1;
+ while (offset > 0)
+ {
+ skip[row, col + offset--] = true;
+ }
+ }
+ }
+
+ ApplyChildLayout(child);
+ }
+ }
+
+ public static readonly AvaloniaProperty ChildHorizontalAlignmentProperty =
+ AvaloniaProperty.Register("ChildHorizontalAlignment");
+
+ public static readonly AvaloniaProperty ChildMarginProperty =
+ AvaloniaProperty.Register("ChildMargin");
+
+ public static readonly AvaloniaProperty ChildVerticalAlignmentProperty =
+ AvaloniaProperty.Register("ChildVerticalAlignment");
+
+ public static readonly AvaloniaProperty ColumnCountProperty =
+ AvaloniaProperty.RegisterAttached("ColumnCount", typeof(AutoGrid), 1);
+
+ public static readonly AvaloniaProperty ColumnWidthProperty =
+ AvaloniaProperty.RegisterAttached("ColumnWidth", typeof(AutoGrid), GridLength.Auto);
+
+ public static readonly AvaloniaProperty IsAutoIndexingProperty =
+ AvaloniaProperty.Register("IsAutoIndexing", true);
+
+ public static readonly AvaloniaProperty OrientationProperty =
+ AvaloniaProperty.Register("Orientation", Orientation.Vertical);
+
+ public static readonly AvaloniaProperty RowCountProperty =
+ AvaloniaProperty.RegisterAttached("RowCount", typeof(AutoGrid), 1);
+
+ public static readonly AvaloniaProperty RowHeightProperty =
+ AvaloniaProperty.RegisterAttached("RowHeight", typeof(AutoGrid), GridLength.Auto);
+
+ static AutoGrid()
+ {
+ AffectsMeasure(ChildHorizontalAlignmentProperty, ChildMarginProperty,
+ ChildVerticalAlignmentProperty, ColumnCountProperty, ColumnWidthProperty, IsAutoIndexingProperty, OrientationProperty,
+ RowHeightProperty);
+
+ ChildHorizontalAlignmentProperty.Changed.Subscribe(OnChildHorizontalAlignmentChanged);
+ ChildMarginProperty.Changed.Subscribe(OnChildMarginChanged);
+ ChildVerticalAlignmentProperty.Changed.Subscribe(OnChildVerticalAlignmentChanged);
+ ColumnCountProperty.Changed.Subscribe(ColumnCountChanged);
+ RowCountProperty.Changed.Subscribe(RowCountChanged);
+ ColumnWidthProperty.Changed.Subscribe(FixedColumnWidthChanged);
+ RowHeightProperty.Changed.Subscribe(FixedRowHeightChanged);
+ }
+
+ #region Overrides
+
+ ///
+ /// Measures the children of a in anticipation of arranging them during the pass.
+ ///
+ /// Indicates an upper limit size that should not be exceeded.
+ ///
+ /// that represents the required size to arrange child content.
+ ///
+ protected override Size MeasureOverride(Size constraint)
+ {
+ PerformLayout();
+ return base.MeasureOverride(constraint);
+ }
+
+ #endregion Overrides
+}
diff --git a/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs b/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs
index 33511dbc..732715a3 100644
--- a/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs
+++ b/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs
@@ -132,6 +132,15 @@ public class BetterContentDialog : ContentDialog
get => GetValue(MaxDialogHeightProperty);
set => SetValue(MaxDialogHeightProperty, value);
}
+
+ public static readonly StyledProperty ContentMarginProperty = AvaloniaProperty.Register(
+ "ContentMargin");
+
+ public Thickness ContentMargin
+ {
+ get => GetValue(ContentMarginProperty);
+ set => SetValue(ContentMarginProperty, value);
+ }
public BetterContentDialog()
@@ -205,6 +214,18 @@ public class BetterContentDialog : ContentDialog
TryBindButtons();
}
+ ///
+ protected override void OnApplyTemplate(TemplateAppliedEventArgs e)
+ {
+ base.OnApplyTemplate(e);
+
+ var background = e.NameScope.Find("BackgroundElement");
+ if (background is not null)
+ {
+ background.Margin = ContentMargin;
+ }
+ }
+
private void OnLoaded(object? sender, RoutedEventArgs? e)
{
TryBindButtons();
diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs
index b0fc04ba..69b8d7f0 100644
--- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs
+++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs
@@ -4,7 +4,6 @@ using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.IO;
-using System.Linq;
using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using StabilityMatrix.Avalonia.Models;
@@ -12,8 +11,8 @@ using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
+using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
-using StabilityMatrix.Avalonia.ViewModels.PackageManager;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Helper;
@@ -25,8 +24,6 @@ using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
using StabilityMatrix.Core.Updater;
-using CheckpointFile = StabilityMatrix.Avalonia.ViewModels.CheckpointManager.CheckpointFile;
-using CheckpointFolder = StabilityMatrix.Avalonia.ViewModels.CheckpointManager.CheckpointFolder;
namespace StabilityMatrix.Avalonia.DesignData;
@@ -91,7 +88,8 @@ public static class DesignData
.AddSingleton()
.AddSingleton()
.AddSingleton()
- .AddSingleton();
+ .AddSingleton()
+ .AddSingleton();
// Placeholder services that nobody should need during design time
services
@@ -192,6 +190,11 @@ public static class DesignData
{
Title = "StableDiffusion",
DirectoryPath = "Packages/Lora/Subfolder",
+ },
+ new(settingsManager, downloadService, modelFinder)
+ {
+ Title = "Lora",
+ DirectoryPath = "Packages/StableDiffusion/Subfolder",
}
},
CheckpointFiles = new AdvancedObservableList
@@ -223,14 +226,42 @@ public static class DesignData
})
};
- ProgressManagerViewModel.ProgressItems = new ObservableCollection
+ NewCheckpointsPageViewModel.AllCheckpoints = new ObservableCollection
{
- new(new ProgressItem(Guid.NewGuid(), "Test File.exe",
- new ProgressReport(0.5f, "Downloading..."))),
- new(new ProgressItem(Guid.NewGuid(), "Test File 2.uwu",
- new ProgressReport(0.25f, "Extracting...")))
+ new()
+ {
+ FilePath = "~/Models/StableDiffusion/electricity-light.safetensors",
+ Title = "Auroral Background",
+ PreviewImagePath = "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/" +
+ "78fd2a0a-42b6-42b0-9815-81cb11bb3d05/00009-2423234823.jpeg",
+ ConnectedModel = new ConnectedModelInfo
+ {
+ VersionName = "Lightning Auroral",
+ BaseModel = "SD 1.5",
+ ModelName = "Auroral Background",
+ ModelType = CivitModelType.Model,
+ FileMetadata = new CivitFileMetadata
+ {
+ Format = CivitModelFormat.SafeTensor,
+ Fp = CivitModelFpType.fp16,
+ Size = CivitModelSize.pruned,
+ }
+ }
+ },
+ new()
+ {
+ FilePath = "~/Models/Lora/model.safetensors",
+ Title = "Some model"
+ }
};
+ ProgressManagerViewModel.ProgressItems.AddRange(new ProgressItemViewModelBase[]
+ {
+ new ProgressItemViewModel(new ProgressItem(Guid.NewGuid(), "Test File.exe",
+ new ProgressReport(0.5f, "Downloading..."))),
+ new MockDownloadProgressItemViewModel("Test File 2.exe"),
+ });
+
UpdateViewModel = Services.GetRequiredService();
UpdateViewModel.UpdateText =
$"Stability Matrix v2.0.1 is now available! You currently have v2.0.0. Would you like to update now?";
@@ -262,10 +293,15 @@ public static class DesignData
{
var settings = Services.GetRequiredService();
var vm = Services.GetRequiredService();
- vm.Packages = new ObservableCollection(
- settings.Settings.InstalledPackages.Select(p =>
- DialogFactory.Get(viewModel => viewModel.Package = p)));
- vm.Packages.First().IsUpdateAvailable = true;
+
+ vm.SetPackages(settings.Settings.InstalledPackages);
+ vm.SetUnknownPackages(new InstalledPackage[]
+ {
+ UnknownInstalledPackage.FromDirectoryName("sd-unknown"),
+ });
+
+ vm.PackageCards[0].IsUpdateAvailable = true;
+
return vm;
}
}
@@ -273,6 +309,9 @@ public static class DesignData
public static CheckpointsPageViewModel CheckpointsPageViewModel =>
Services.GetRequiredService();
+ public static NewCheckpointsPageViewModel NewCheckpointsPageViewModel =>
+ Services.GetRequiredService();
+
public static SettingsViewModel SettingsViewModel =>
Services.GetRequiredService();
@@ -368,6 +407,9 @@ public static class DesignData
};
});
+ public static PackageImportViewModel PackageImportViewModel =>
+ DialogFactory.Get();
+
public static RefreshBadgeViewModel RefreshBadgeViewModel => new()
{
State = ProgressState.Success
diff --git a/StabilityMatrix.Avalonia/DesignData/MockDownloadProgressItemViewModel.cs b/StabilityMatrix.Avalonia/DesignData/MockDownloadProgressItemViewModel.cs
new file mode 100644
index 00000000..5f94c91d
--- /dev/null
+++ b/StabilityMatrix.Avalonia/DesignData/MockDownloadProgressItemViewModel.cs
@@ -0,0 +1,65 @@
+using System.Threading;
+using System.Threading.Tasks;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Models.Progress;
+
+namespace StabilityMatrix.Avalonia.DesignData;
+
+public class MockDownloadProgressItemViewModel : PausableProgressItemViewModelBase
+{
+ private Task? dummyTask;
+ private CancellationTokenSource? cts;
+
+ public MockDownloadProgressItemViewModel(string fileName)
+ {
+ Name = fileName;
+ Progress.Value = 5;
+ Progress.IsIndeterminate = false;
+ Progress.Text = "Downloading...";
+ }
+
+ ///
+ public override Task Cancel()
+ {
+ // Cancel the task that updates progress
+ cts?.Cancel();
+ cts = null;
+ dummyTask = null;
+
+ State = ProgressState.Cancelled;
+ Progress.Text = "Cancelled";
+ return Task.CompletedTask;
+ }
+
+ ///
+ public override Task Pause()
+ {
+ // Cancel the task that updates progress
+ cts?.Cancel();
+ cts = null;
+ dummyTask = null;
+
+ State = ProgressState.Inactive;
+
+ return Task.CompletedTask;
+ }
+
+ ///
+ public override Task Resume()
+ {
+ // Start a task that updates progress every 100ms
+ cts = new CancellationTokenSource();
+ dummyTask = Task.Run(async () =>
+ {
+ while (State != ProgressState.Success)
+ {
+ await Task.Delay(100, cts.Token);
+ Progress.Value += 1;
+ }
+ }, cts.Token);
+
+ State = ProgressState.Working;
+
+ return Task.CompletedTask;
+ }
+}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs
index 419d972e..f3b35519 100644
--- a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs
+++ b/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs
@@ -1,5 +1,6 @@
using System;
using System.IO;
+using System.Threading;
using System.Threading.Tasks;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Services;
@@ -8,8 +9,16 @@ namespace StabilityMatrix.Avalonia.DesignData;
public class MockDownloadService : IDownloadService
{
- public Task DownloadToFileAsync(string downloadUrl, string downloadPath,
- IProgress? progress = null, string? httpClientName = null)
+ public Task DownloadToFileAsync(string downloadUrl, string downloadPath, IProgress? progress = null,
+ string? httpClientName = null, CancellationToken cancellationToken = default)
+ {
+ return Task.CompletedTask;
+ }
+
+ ///
+ public Task ResumeDownloadToFileAsync(string downloadUrl, string downloadPath, long existingFileSize,
+ IProgress? progress = null, string? httpClientName = null,
+ CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
diff --git a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs
new file mode 100644
index 00000000..4522bd2d
--- /dev/null
+++ b/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Avalonia.DesignData;
+
+public class MockTrackedDownloadService : ITrackedDownloadService
+{
+ ///
+ public IEnumerable Downloads => Array.Empty();
+
+ ///
+ public event EventHandler? DownloadAdded;
+
+ ///
+ public TrackedDownload NewDownload(Uri downloadUrl, FilePath downloadPath)
+ {
+ throw new NotImplementedException();
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Extensions/DirectoryPathExtensions.cs b/StabilityMatrix.Avalonia/Extensions/DirectoryPathExtensions.cs
new file mode 100644
index 00000000..508d0818
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Extensions/DirectoryPathExtensions.cs
@@ -0,0 +1,91 @@
+using System;
+using System.Diagnostics.CodeAnalysis;
+using System.IO;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
+using Polly;
+using StabilityMatrix.Core.Models.FileInterfaces;
+
+namespace StabilityMatrix.Avalonia.Extensions;
+
+[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+public static class DirectoryPathExtensions
+{
+ ///
+ /// Deletes a directory and all of its contents recursively.
+ /// Uses Polly to retry the deletion if it fails, up to 5 times with an exponential backoff.
+ ///
+ public static Task DeleteVerboseAsync(this DirectoryPath directory, ILogger? logger = default)
+ {
+ var policy = Policy.Handle()
+ .WaitAndRetryAsync(3, attempt => TimeSpan.FromMilliseconds(50 * Math.Pow(2, attempt)),
+ onRetry: (exception, calculatedWaitDuration) =>
+ {
+ logger?.LogWarning(
+ exception,
+ "Deletion of {TargetDirectory} failed. Retrying in {CalculatedWaitDuration}",
+ directory, calculatedWaitDuration);
+ });
+
+ return policy.ExecuteAsync(async () =>
+ {
+ await Task.Run(() => { DeleteVerbose(directory, logger); });
+ });
+ }
+
+ ///
+ /// Deletes a directory and all of its contents recursively.
+ /// Removes link targets without deleting the source.
+ ///
+ public static void DeleteVerbose(this DirectoryPath directory, ILogger? logger = default)
+ {
+ // Skip if directory does not exist
+ if (!directory.Exists)
+ {
+ return;
+ }
+ // For junction points, delete with recursive false
+ if (directory.IsSymbolicLink)
+ {
+ logger?.LogInformation("Removing junction point {TargetDirectory}", directory);
+ try
+ {
+ directory.Delete(false);
+ return;
+ }
+ catch (IOException ex)
+ {
+ throw new IOException($"Failed to delete junction point {directory}", ex);
+ }
+ }
+ // Recursively delete all subdirectories
+ foreach (var subDir in directory.Info.EnumerateDirectories())
+ {
+ DeleteVerbose(subDir, logger);
+ }
+
+ // Delete all files in the directory
+ foreach (var filePath in directory.Info.EnumerateFiles())
+ {
+ try
+ {
+ filePath.Attributes = FileAttributes.Normal;
+ filePath.Delete();
+ }
+ catch (IOException ex)
+ {
+ throw new IOException($"Failed to delete file {filePath.FullName}", ex);
+ }
+ }
+
+ // Delete this directory
+ try
+ {
+ directory.Delete(false);
+ }
+ catch (IOException ex)
+ {
+ throw new IOException($"Failed to delete directory {directory}", ex);
+ }
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Languages/Cultures.cs b/StabilityMatrix.Avalonia/Languages/Cultures.cs
new file mode 100644
index 00000000..542fd18b
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Languages/Cultures.cs
@@ -0,0 +1,52 @@
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
+
+namespace StabilityMatrix.Avalonia.Languages;
+
+[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+public static class Cultures
+{
+ public static CultureInfo Default { get; } = new("en-US");
+
+ public static CultureInfo Current => Resources.Culture;
+
+ public static readonly Dictionary SupportedCulturesByCode =
+ new Dictionary
+ {
+ ["en-US"] = Default,
+ ["ja-JP"] = new("ja-JP")
+ };
+
+ public static IReadOnlyList SupportedCultures
+ => SupportedCulturesByCode.Values.ToImmutableList();
+
+ public static CultureInfo GetSupportedCultureOrDefault(string? cultureCode)
+ {
+ if (cultureCode is null
+ || !SupportedCulturesByCode.TryGetValue(cultureCode, out var culture))
+ {
+ return Default;
+ }
+
+ return culture;
+ }
+
+ public static bool TrySetSupportedCulture(string? cultureCode)
+ {
+ if (cultureCode is null
+ || !SupportedCulturesByCode.TryGetValue(cultureCode, out var culture))
+ {
+ return false;
+ }
+
+ Resources.Culture = culture;
+ return true;
+ }
+
+ public static bool TrySetSupportedCulture(CultureInfo? cultureInfo)
+ {
+ return cultureInfo is not null && TrySetSupportedCulture(cultureInfo.Name);
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs b/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs
new file mode 100644
index 00000000..27757b99
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Languages/Resources.Designer.cs
@@ -0,0 +1,206 @@
+//------------------------------------------------------------------------------
+//
+// This code was generated by a tool.
+//
+// Changes to this file may cause incorrect behavior and will be lost if
+// the code is regenerated.
+//
+//------------------------------------------------------------------------------
+
+namespace StabilityMatrix.Avalonia.Languages {
+ using System;
+
+
+ ///
+ /// A strongly-typed resource class, for looking up localized strings, etc.
+ ///
+ // This class was auto-generated by the StronglyTypedResourceBuilder
+ // class via a tool like ResGen or Visual Studio.
+ // To add or remove a member, edit your .ResX file then rerun ResGen
+ // with the /str option, or rebuild your VS project.
+ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")]
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute()]
+ [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()]
+ public class Resources {
+
+ private static global::System.Resources.ResourceManager resourceMan;
+
+ private static global::System.Globalization.CultureInfo resourceCulture;
+
+ [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")]
+ internal Resources() {
+ }
+
+ ///
+ /// Returns the cached ResourceManager instance used by this class.
+ ///
+ [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)]
+ public static global::System.Resources.ResourceManager ResourceManager {
+ get {
+ if (object.ReferenceEquals(resourceMan, null)) {
+ global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("StabilityMatrix.Avalonia.Languages.Resources", typeof(Resources).Assembly);
+ resourceMan = temp;
+ }
+ return resourceMan;
+ }
+ }
+
+ ///
+ /// Overrides the current thread's CurrentUICulture property for all
+ /// resource lookups using this strongly typed resource class.
+ ///
+ [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)]
+ public static global::System.Globalization.CultureInfo Culture {
+ get {
+ return resourceCulture;
+ }
+ set {
+ resourceCulture = value;
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Cancel.
+ ///
+ public static string Action_Cancel {
+ get {
+ return ResourceManager.GetString("Action_Cancel", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Import.
+ ///
+ public static string Action_Import {
+ get {
+ return ResourceManager.GetString("Action_Import", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Launch.
+ ///
+ public static string Action_Launch {
+ get {
+ return ResourceManager.GetString("Action_Launch", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Quit.
+ ///
+ public static string Action_Quit {
+ get {
+ return ResourceManager.GetString("Action_Quit", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Relaunch.
+ ///
+ public static string Action_Relaunch {
+ get {
+ return ResourceManager.GetString("Action_Relaunch", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Relaunch Later.
+ ///
+ public static string Action_RelaunchLater {
+ get {
+ return ResourceManager.GetString("Action_RelaunchLater", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Save.
+ ///
+ public static string Action_Save {
+ get {
+ return ResourceManager.GetString("Action_Save", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Branches.
+ ///
+ public static string Label_Branches {
+ get {
+ return ResourceManager.GetString("Label_Branches", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Language.
+ ///
+ public static string Label_Language {
+ get {
+ return ResourceManager.GetString("Label_Language", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Package Type.
+ ///
+ public static string Label_PackageType {
+ get {
+ return ResourceManager.GetString("Label_PackageType", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Relaunch Required.
+ ///
+ public static string Label_RelaunchRequired {
+ get {
+ return ResourceManager.GetString("Label_RelaunchRequired", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Releases.
+ ///
+ public static string Label_Releases {
+ get {
+ return ResourceManager.GetString("Label_Releases", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Unknown Package.
+ ///
+ public static string Label_UnknownPackage {
+ get {
+ return ResourceManager.GetString("Label_UnknownPackage", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Version.
+ ///
+ public static string Label_Version {
+ get {
+ return ResourceManager.GetString("Label_Version", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Version Type.
+ ///
+ public static string Label_VersionType {
+ get {
+ return ResourceManager.GetString("Label_VersionType", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Relaunch is required for new language option to take effect.
+ ///
+ public static string Text_RelaunchRequiredToApplyLanguage {
+ get {
+ return ResourceManager.GetString("Text_RelaunchRequiredToApplyLanguage", resourceCulture);
+ }
+ }
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Languages/Resources.ja-JP.resx b/StabilityMatrix.Avalonia/Languages/Resources.ja-JP.resx
new file mode 100644
index 00000000..ccd583bf
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Languages/Resources.ja-JP.resx
@@ -0,0 +1,23 @@
+
+
+ text/microsoft-resx
+
+
+ 1.3
+
+
+ System.Resources.ResXResourceReader, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089
+
+
+ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089
+
+
+ 保存
+
+
+ 戻る
+
+
+ 言語
+
+
diff --git a/StabilityMatrix.Avalonia/Languages/Resources.resx b/StabilityMatrix.Avalonia/Languages/Resources.resx
new file mode 100644
index 00000000..93626903
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Languages/Resources.resx
@@ -0,0 +1,69 @@
+
+
+
+
+
+
+
+
+
+ text/microsoft-resx
+
+
+ 1.3
+
+
+ System.Resources.ResXResourceReader, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089
+
+
+ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089
+
+
+ Launch
+
+
+ Quit
+
+
+ Save
+
+
+ Cancel
+
+
+ Language
+
+
+ Relaunch is required for new language option to take effect
+
+
+ Relaunch
+
+
+ Relaunch Later
+
+
+ Relaunch Required
+
+
+ Unknown Package
+
+
+ Import
+
+
+ Package Type
+
+
+ Version
+
+
+ Version Type
+
+
+ Releases
+
+
+ Branches
+
+
diff --git a/StabilityMatrix.Avalonia/Services/ServiceManager.cs b/StabilityMatrix.Avalonia/Services/ServiceManager.cs
index dc7387cf..ee957cec 100644
--- a/StabilityMatrix.Avalonia/Services/ServiceManager.cs
+++ b/StabilityMatrix.Avalonia/Services/ServiceManager.cs
@@ -1,10 +1,14 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
+using Avalonia.Controls;
using Microsoft.Extensions.DependencyInjection;
+using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Services;
+[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public class ServiceManager
{
// Holds providers
@@ -111,6 +115,48 @@ public class ServiceManager
$"Service of type {typeof(TService)} is not registered for {typeof(T)}");
}
+ ///
+ /// Get a view model instance from runtime type
+ ///
+ [SuppressMessage("ReSharper", "InconsistentlySynchronizedField")]
+ public T Get(Type serviceType)
+ {
+ if (!serviceType.IsAssignableFrom(typeof(T)))
+ {
+ throw new ArgumentException(
+ $"Service type {serviceType} is not assignable from {typeof(T)}");
+ }
+
+ if (instances.TryGetValue(serviceType, out var instance))
+ {
+ if (instance is null)
+ {
+ throw new ArgumentException(
+ $"Service of type {serviceType} was registered as null");
+ }
+ return (T) instance;
+ }
+
+ if (providers.TryGetValue(serviceType, out var provider))
+ {
+ if (provider is null)
+ {
+ throw new ArgumentException(
+ $"Service of type {serviceType} was registered as null");
+ }
+ var result = provider();
+ if (result is null)
+ {
+ throw new ArgumentException(
+ $"Service provider for type {serviceType} returned null");
+ }
+ return (T) result;
+ }
+
+ throw new ArgumentException(
+ $"Service of type {serviceType} is not registered for {typeof(T)}");
+ }
+
///
/// Get a view model instance with an initializer parameter
///
@@ -129,4 +175,50 @@ public class ServiceManager
initializer(instance);
return instance;
}
+
+ ///
+ /// Get a view model instance, set as DataContext of its View, and return
+ /// a BetterContentDialog with that View as its Content
+ ///
+ public BetterContentDialog GetDialog() where TService : T
+ {
+ var instance = Get()!;
+
+ if (Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute)) is not ViewAttribute
+ viewAttr)
+ {
+ throw new InvalidOperationException($"View not found for {instance.GetType().FullName}");
+ }
+
+ if (Activator.CreateInstance(viewAttr.GetViewType()) is not Control view)
+ {
+ throw new NullReferenceException($"Unable to create instance for {instance.GetType().FullName}");
+ }
+
+ return new BetterContentDialog { Content = view };
+ }
+
+ ///
+ /// Get a view model instance with initializer, set as DataContext of its View, and return
+ /// a BetterContentDialog with that View as its Content
+ ///
+ public BetterContentDialog GetDialog(Action initializer) where TService : T
+ {
+ var instance = Get(initializer)!;
+
+ if (Attribute.GetCustomAttribute(instance.GetType(), typeof(ViewAttribute)) is not ViewAttribute
+ viewAttr)
+ {
+ throw new InvalidOperationException($"View not found for {instance.GetType().FullName}");
+ }
+
+ if (Activator.CreateInstance(viewAttr.GetViewType()) is not Control view)
+ {
+ throw new NullReferenceException($"Unable to create instance for {instance.GetType().FullName}");
+ }
+
+ view.DataContext = instance;
+
+ return new BetterContentDialog { Content = view };
+ }
}
diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
index 2acc5271..bad8f426 100644
--- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
+++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
@@ -8,7 +8,7 @@
app.manifesttrue./Assets/Icon.ico
- 2.1.1-dev.1
+ 2.3.0-dev.1$(Version)true
@@ -25,6 +25,7 @@
+
@@ -76,4 +77,19 @@
+
+
+
+ PublicResXFileCodeGenerator
+ Resources.Designer.cs
+
+
+
+
+
+ True
+ True
+ Resources.resx
+
+
diff --git a/StabilityMatrix.Avalonia/Styles/ButtonStyles.axaml b/StabilityMatrix.Avalonia/Styles/ButtonStyles.axaml
index ab433d39..18de2349 100644
--- a/StabilityMatrix.Avalonia/Styles/ButtonStyles.axaml
+++ b/StabilityMatrix.Avalonia/Styles/ButtonStyles.axaml
@@ -11,6 +11,7 @@
+
@@ -302,4 +303,45 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Styles/ToggleButtonStyles.axaml b/StabilityMatrix.Avalonia/Styles/ToggleButtonStyles.axaml
new file mode 100644
index 00000000..2557dcea
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Styles/ToggleButtonStyles.axaml
@@ -0,0 +1,347 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/PausableProgressItemViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/PausableProgressItemViewModelBase.cs
new file mode 100644
index 00000000..2d0f2672
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/PausableProgressItemViewModelBase.cs
@@ -0,0 +1,49 @@
+using System.Diagnostics.CodeAnalysis;
+using System.Threading.Tasks;
+using CommunityToolkit.Mvvm.ComponentModel;
+using CommunityToolkit.Mvvm.Input;
+using StabilityMatrix.Core.Models.Progress;
+
+namespace StabilityMatrix.Avalonia.ViewModels.Base;
+
+[SuppressMessage("ReSharper", "VirtualMemberNeverOverridden.Global")]
+public abstract partial class PausableProgressItemViewModelBase : ProgressItemViewModelBase
+{
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(IsPaused), nameof(IsCompleted), nameof(CanPauseResume), nameof(CanCancel))]
+ private ProgressState state = ProgressState.Inactive;
+
+ ///
+ /// Whether the progress is paused
+ ///
+ public bool IsPaused => State == ProgressState.Inactive;
+
+ ///
+ /// Whether the progress has succeeded, failed or was cancelled
+ ///
+ public override bool IsCompleted => State is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled;
+
+ public virtual bool SupportsPauseResume => true;
+ public virtual bool SupportsCancel => true;
+
+ public bool CanPauseResume => SupportsPauseResume && !IsCompleted;
+ public bool CanCancel => SupportsCancel && !IsCompleted;
+
+ private AsyncRelayCommand? pauseCommand;
+ public IAsyncRelayCommand PauseCommand => pauseCommand ??= new AsyncRelayCommand(Pause);
+ public virtual Task Pause() => Task.CompletedTask;
+
+ private AsyncRelayCommand? resumeCommand;
+ public IAsyncRelayCommand ResumeCommand => resumeCommand ??= new AsyncRelayCommand(Resume);
+ public virtual Task Resume() => Task.CompletedTask;
+
+ private AsyncRelayCommand? cancelCommand;
+ public IAsyncRelayCommand CancelCommand => cancelCommand ??= new AsyncRelayCommand(Cancel);
+ public virtual Task Cancel() => Task.CompletedTask;
+
+ [RelayCommand]
+ private Task TogglePauseResume()
+ {
+ return IsPaused ? Resume() : Pause();
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs
new file mode 100644
index 00000000..a304da5b
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/Base/ProgressItemViewModelBase.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Threading.Tasks;
+using CommunityToolkit.Mvvm.ComponentModel;
+
+namespace StabilityMatrix.Avalonia.ViewModels.Base;
+
+public abstract partial class ProgressItemViewModelBase : ViewModelBase
+{
+ [ObservableProperty] private Guid id;
+ [ObservableProperty] private string? name;
+ [ObservableProperty] private bool failed;
+
+ public virtual bool IsCompleted => Progress.Value >= 100 || Failed;
+
+ public ProgressViewModel Progress { get; } = new();
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs
index 3b7f0120..e9b2bd86 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CheckpointBrowserCardViewModel.cs
@@ -32,10 +32,10 @@ using Notification = Avalonia.Controls.Notifications.Notification;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel
-
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly IDownloadService downloadService;
+ private readonly ITrackedDownloadService trackedDownloadService;
private readonly ISettingsManager settingsManager;
private readonly ServiceManager dialogFactory;
private readonly INotificationService notificationService;
@@ -63,11 +63,13 @@ public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel
public CheckpointBrowserCardViewModel(
IDownloadService downloadService,
+ ITrackedDownloadService trackedDownloadService,
ISettingsManager settingsManager,
ServiceManager dialogFactory,
INotificationService notificationService)
{
this.downloadService = downloadService;
+ this.trackedDownloadService = trackedDownloadService;
this.settingsManager = settingsManager;
this.dialogFactory = dialogFactory;
this.notificationService = notificationService;
@@ -197,6 +199,53 @@ public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel
await DoImport(model, selectedVersion, selectedFile);
}
+ private static async Task SaveCmInfo(
+ CivitModel model,
+ CivitModelVersion modelVersion,
+ CivitFile modelFile,
+ DirectoryPath downloadDirectory)
+ {
+ var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Name);
+ var modelInfo =
+ new ConnectedModelInfo(model, modelVersion, modelFile, DateTime.UtcNow);
+
+ await modelInfo.SaveJsonToDirectory(downloadDirectory, modelFileName);
+
+ var jsonName = $"{modelFileName}.cm-info.json";
+ return downloadDirectory.JoinFile(jsonName);
+ }
+
+ ///
+ /// Saves the preview image to the same directory as the model file
+ ///
+ ///
+ ///
+ /// The file path of the saved preview image
+ private async Task SavePreviewImage(CivitModelVersion modelVersion, FilePath modelFilePath)
+ {
+ // Skip if model has no images
+ if (modelVersion.Images == null || modelVersion.Images.Count == 0)
+ {
+ return null;
+ }
+
+ var image = modelVersion.Images[0];
+ var imageExtension = Path.GetExtension(image.Url).TrimStart('.');
+ if (imageExtension is "jpg" or "jpeg" or "png")
+ {
+ var imageDownloadPath =
+ modelFilePath.Directory!.JoinFile($"{modelFilePath.Name}.preview.{imageExtension}");
+
+ var imageTask =
+ downloadService.DownloadToFileAsync(image.Url, imageDownloadPath);
+ await notificationService.TryAsync(imageTask, "Could not download preview image");
+
+ return imageDownloadPath;
+ }
+
+ return null;
+ }
+
private async Task DoImport(CivitModel model, CivitModelVersion? selectedVersion = null,
CivitFile? selectedFile = null)
{
@@ -204,164 +253,96 @@ public partial class CheckpointBrowserCardViewModel : Base.ProgressViewModel
Text = "Downloading...";
OnDownloadStart?.Invoke(this);
+
+ // Get latest version
+ var modelVersion = selectedVersion ?? model.ModelVersions?.FirstOrDefault();
+ if (modelVersion is null)
+ {
+ notificationService.Show(new Notification("Model has no versions available",
+ "This model has no versions available for download", NotificationType.Warning));
+ Text = "Unable to Download";
+ return;
+ }
- // Holds files to be deleted on errors
- var filesForCleanup = new HashSet();
-
- // Set Text when exiting, finally block will set 100 and delay clear progress
- try
+ // Get latest version file
+ var modelFile = selectedFile ??
+ modelVersion.Files?.FirstOrDefault(x => x.Type == CivitFileType.Model);
+ if (modelFile is null)
{
- // Get latest version
- var modelVersion = selectedVersion ?? model.ModelVersions?.FirstOrDefault();
- if (modelVersion is null)
- {
- notificationService.Show(new Notification("Model has no versions available",
- "This model has no versions available for download", NotificationType.Warning));
- Text = "Unable to Download";
- return;
- }
+ notificationService.Show(new Notification("Model has no files available",
+ "This model has no files available for download", NotificationType.Warning));
+ Text = "Unable to Download";
+ return;
+ }
+
+ var rootModelsDirectory = new DirectoryPath(settingsManager.ModelsDirectory);
+
+ var downloadDirectory =
+ rootModelsDirectory.JoinDir(model.Type.ConvertTo()
+ .GetStringValue());
+ // Folders might be missing if user didn't install any packages yet
+ downloadDirectory.Create();
+
+ var downloadPath = downloadDirectory.JoinFile(modelFile.Name);
- // Get latest version file
- var modelFile = selectedFile ??
- modelVersion.Files?.FirstOrDefault(x => x.Type == CivitFileType.Model);
- if (modelFile is null)
+ // Download model info and preview first
+ var cmInfoPath = await SaveCmInfo(model, modelVersion, modelFile, downloadDirectory);
+ var previewImagePath = await SavePreviewImage(modelVersion, downloadPath);
+
+ // Create tracked download
+ var download = trackedDownloadService.NewDownload(modelFile.DownloadUrl, downloadPath);
+
+ // Add hash info
+ download.ExpectedHashSha256 = modelFile.Hashes.SHA256;
+
+ // Add files to cleanup list
+ download.ExtraCleanupFileNames.Add(cmInfoPath);
+ if (previewImagePath is not null)
+ {
+ download.ExtraCleanupFileNames.Add(previewImagePath);
+ }
+
+ // Attach for progress updates
+ download.ProgressUpdate += (s, e) =>
+ {
+ Value = e.Percentage;
+ if (e.Type == ProgressType.Hashing)
{
- notificationService.Show(new Notification("Model has no files available",
- "This model has no files available for download", NotificationType.Warning));
- Text = "Unable to Download";
- return;
+ Text = $"Validating... {e.Percentage}%";
}
-
- var downloadFolder = Path.Combine(settingsManager.ModelsDirectory,
- model.Type.ConvertTo().GetStringValue());
- // Folders might be missing if user didn't install any packages yet
- Directory.CreateDirectory(downloadFolder);
- var downloadPath = Path.GetFullPath(Path.Combine(downloadFolder, modelFile.Name));
- filesForCleanup.Add(downloadPath);
-
- // Do the download
- var progressId = Guid.NewGuid();
- var downloadTask = downloadService.DownloadToFileAsync(modelFile.DownloadUrl,
- downloadPath,
- new Progress(report =>
- {
- if (Math.Abs(report.Percentage - Value) > 0.1)
- {
- Dispatcher.UIThread.Post(() =>
- {
- Value = report.Percentage;
- Text = $"Downloading... {report.Percentage}%";
- });
- EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- modelFile.Name, report));
- }
- }));
-
- var downloadResult =
- await notificationService.TryAsync(downloadTask, "Could not download file");
-
- // Failed download handling
- if (downloadResult.Exception is not null)
+ else
{
- // For exceptions other than ApiException or TaskCanceledException, log error
- var logLevel = downloadResult.Exception switch
- {
- HttpRequestException or ApiException or TaskCanceledException => LogLevel.Warn,
- _ => LogLevel.Error
- };
- Logger.Log(logLevel, downloadResult.Exception, "Error during model download");
-
- Text = "Download Failed";
- EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- modelFile.Name, new ProgressReport(0f), true));
- return;
+ Text = $"Downloading... {e.Percentage}%";
}
+ };
- // When sha256 is available, validate the downloaded file
- var fileExpectedSha256 = modelFile.Hashes.SHA256;
- if (!string.IsNullOrEmpty(fileExpectedSha256))
+ download.ProgressStateChanged += (s, e) =>
+ {
+ if (e == ProgressState.Success)
{
- var hashProgress = new Progress(progress =>
- {
- Value = progress.Percentage;
- Text = $"Validating... {progress.Percentage}%";
- EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- modelFile.Name, progress));
- });
- var sha256 = await FileHash.GetSha256Async(downloadPath, hashProgress);
- if (sha256 != fileExpectedSha256.ToLowerInvariant())
- {
- Text = "Import Failed!";
- DelayedClearProgress(TimeSpan.FromMilliseconds(800));
- notificationService.Show(new Notification("Download failed hash validation",
- "This may be caused by network or server issues from CivitAI, please try again in a few minutes.",
- NotificationType.Error));
- Text = "Download Failed";
-
- EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- modelFile.Name, new ProgressReport(0f), true));
- return;
- }
-
- settingsManager.Transaction(
- s => s.InstalledModelHashes.Add(modelFile.Hashes.BLAKE3));
-
- notificationService.Show(new Notification("Import complete",
- $"{model.Type} {model.Name} imported successfully!", NotificationType.Success));
+ Text = "Import Complete";
+
+ IsIndeterminate = false;
+ Value = 100;
+ CheckIfInstalled();
+ DelayedClearProgress(TimeSpan.FromMilliseconds(800));
}
-
- IsIndeterminate = true;
-
- // Save connected model info
- var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Name);
- var modelInfo =
- new ConnectedModelInfo(CivitModel, modelVersion, modelFile, DateTime.UtcNow);
- var modelInfoPath = Path.GetFullPath(Path.Combine(
- downloadFolder, modelFileName + ConnectedModelInfo.FileExtension));
- filesForCleanup.Add(modelInfoPath);
- await modelInfo.SaveJsonToDirectory(downloadFolder, modelFileName);
-
- // If available, save a model image
- if (modelVersion.Images != null && modelVersion.Images.Any())
+ else if (e == ProgressState.Cancelled)
{
- var image = modelVersion.Images[0];
- var imageExtension = Path.GetExtension(image.Url).TrimStart('.');
- if (imageExtension is "jpg" or "jpeg" or "png")
- {
- var imageDownloadPath = Path.GetFullPath(Path.Combine(downloadFolder,
- $"{modelFileName}.preview.{imageExtension}"));
- filesForCleanup.Add(imageDownloadPath);
- var imageTask =
- downloadService.DownloadToFileAsync(image.Url, imageDownloadPath);
- await notificationService.TryAsync(imageTask, "Could not download preview image");
- }
+ Text = "Cancelled";
+ DelayedClearProgress(TimeSpan.FromMilliseconds(500));
}
-
- // Successful - clear cleanup list
- filesForCleanup.Clear();
-
- EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- modelFile.Name, new ProgressReport(1f, "Import complete")));
-
- Text = "Import complete!";
- }
- catch (Exception e)
- {
- Debug.WriteLine(e);
- }
- finally
- {
- foreach (var file in filesForCleanup.Where(file => file.Exists))
+ else if (e == ProgressState.Failed)
{
- file.Delete();
- Logger.Info($"Download cleanup: Deleted file {file}");
+ Text = "Download Failed";
+ DelayedClearProgress(TimeSpan.FromMilliseconds(800));
}
-
- IsIndeterminate = false;
- Value = 100;
- CheckIfInstalled();
- DelayedClearProgress(TimeSpan.FromMilliseconds(800));
- }
+ };
+
+ // Add hash context action
+ download.ContextAction = CivitPostDownloadContextAction.FromCivitFile(modelFile);
+
+ download.Start();
}
private void DelayedClearProgress(TimeSpan delay)
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs
index 53abad57..42bad566 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs
@@ -1,14 +1,19 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
+using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Net.Http;
+using System.Reactive;
+using System.Reactive.Linq;
+using System.Reactive.Threading.Tasks;
+using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Collections;
using Avalonia.Controls;
-using Avalonia.Controls.Notifications;
+using AvaloniaEdit.Utils;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
@@ -17,6 +22,7 @@ using Refit;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
+using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Attributes;
@@ -28,6 +34,7 @@ using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Settings;
using StabilityMatrix.Core.Services;
+using Notification = Avalonia.Controls.Notifications.Notification;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
@@ -60,10 +67,12 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
[ObservableProperty] private bool hasSearched;
[ObservableProperty] private bool canGoToNextPage;
[ObservableProperty] private bool canGoToPreviousPage;
+ [ObservableProperty] private bool canGoToFirstPage;
+ [ObservableProperty] private bool canGoToLastPage;
[ObservableProperty] private bool isIndeterminate;
[ObservableProperty] private bool noResultsFound;
[ObservableProperty] private string noResultsText = string.Empty;
- [ObservableProperty] private string selectedBaseModelType = "All";
+ [ObservableProperty] private string selectedBaseModelType = "All";
private List allModelCards = new();
@@ -94,6 +103,16 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
CurrentPageNumber = 1;
CanGoToNextPage = true;
+ CanGoToLastPage = true;
+
+ Observable
+ .FromEventPattern(this, nameof(PropertyChanged))
+ .Where(x => x.EventArgs.PropertyName == nameof(CurrentPageNumber))
+ .Throttle(TimeSpan.FromMilliseconds(250))
+ .Select(_ => CurrentPageNumber)
+ .Where(page => page <= TotalPages && page > 0)
+ .ObserveOn(SynchronizationContext.Current)
+ .Subscribe(_ => TrySearchAgain(false).SafeFireAndForget(), err => Logger.Error(err));
}
public override void OnLoaded()
@@ -102,6 +121,14 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
var searchOptions = settingsManager.Settings.ModelSearchOptions;
+ // Fix SelectedModelType if someone had selected the obsolete "Model" option
+ if (searchOptions is {SelectedModelType: CivitModelType.Model})
+ {
+ settingsManager.Transaction(s => s.ModelSearchOptions = new ModelSearchOptions(
+ SelectedPeriod, SortMode, CivitModelType.Checkpoint, SelectedBaseModelType));
+ searchOptions = settingsManager.Settings.ModelSearchOptions;
+ }
+
SelectedPeriod = searchOptions?.SelectedPeriod ?? CivitPeriod.Month;
SortMode = searchOptions?.SortMode ?? CivitSortMode.HighestRated;
SelectedModelType = searchOptions?.SelectedModelType ?? CivitModelType.Checkpoint;
@@ -250,13 +277,21 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
}).ToList();
allModelCards = updateCards;
- ModelCards =
- new ObservableCollection(
- updateCards.Where(FilterModelCardsPredicate));
+
+ var filteredCards = updateCards.Where(FilterModelCardsPredicate);
+ if (SortMode == CivitSortMode.Installed)
+ {
+ filteredCards =
+ filteredCards.OrderByDescending(x => x.UpdateCardText == "Update Available");
+ }
+
+ ModelCards =new ObservableCollection(filteredCards);
}
TotalPages = metadata?.TotalPages ?? 1;
+ CanGoToFirstPage = CurrentPageNumber != 1;
CanGoToPreviousPage = CurrentPageNumber > 1;
CanGoToNextPage = CurrentPageNumber < TotalPages;
+ CanGoToLastPage = CurrentPageNumber != TotalPages;
// Status update
ShowMainLoadingSpinner = false;
IsIndeterminate = false;
@@ -309,6 +344,26 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
{
modelRequest.BaseModel = SelectedBaseModelType;
}
+
+ if (SortMode == CivitSortMode.Installed)
+ {
+ var connectedModels =
+ CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory)
+ .Where(c => c.IsConnectedModel);
+
+ if (SelectedModelType != CivitModelType.All)
+ {
+ connectedModels = connectedModels.Where(c => c.ModelType == SelectedModelType);
+ }
+
+ modelRequest = new CivitModelsRequest
+ {
+ CommaSeparatedModelIds = string.Join(",",
+ connectedModels.Select(c => c.ConnectedModel!.ModelId).GroupBy(m => m)
+ .Select(g => g.First())),
+ Types = SelectedModelType == CivitModelType.All ? null : new[] {SelectedModelType}
+ };
+ }
// See if query is cached
var cachedQuery = await liteDbContext.CivitModelQueryCache
@@ -344,22 +399,32 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase
UpdateResultsText();
}
- [RelayCommand]
- private async Task PreviousPage()
+ public void FirstPage()
{
- if (CurrentPageNumber == 1) return;
+ CurrentPageNumber = 1;
+ }
+ public void PreviousPage()
+ {
+ if (CurrentPageNumber == 1)
+ return;
+
CurrentPageNumber--;
- await TrySearchAgain(false);
}
-
- [RelayCommand]
- private async Task NextPage()
+
+ public void NextPage()
{
+ if (CurrentPageNumber == TotalPages)
+ return;
+
CurrentPageNumber++;
- await TrySearchAgain(false);
}
-
+
+ public void LastPage()
+ {
+ CurrentPageNumber = TotalPages;
+ }
+
partial void OnShowNsfwChanged(bool value)
{
settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled, value);
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs
index e540599c..8ecfff45 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFile.cs
@@ -14,6 +14,7 @@ using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
@@ -46,6 +47,7 @@ public partial class CheckpointFile : ViewModelBase
public bool IsConnectedModel => ConnectedModel != null;
[ObservableProperty] private bool isLoading;
+ [ObservableProperty] private CivitModelType modelType;
public string FileName => Path.GetFileName((string?) FilePath);
@@ -196,45 +198,63 @@ public partial class CheckpointFile : ViewModelBase
///
public static IEnumerable FromDirectoryIndex(string directory, SearchOption searchOption = SearchOption.TopDirectoryOnly)
{
- // Get all files with supported extensions
- var allExtensions = SupportedCheckpointExtensions
- .Concat(SupportedImageExtensions)
- .Concat(SupportedMetadataExtensions);
-
- var files = allExtensions.AsParallel()
- .SelectMany(pattern => Directory.EnumerateFiles(directory, $"*{pattern}", searchOption)).ToDictionary(Path.GetFileName);
-
- foreach (var file in files.Keys.Where(k => SupportedCheckpointExtensions.Contains(Path.GetExtension(k))))
+ foreach (var file in Directory.EnumerateFiles(directory, "*.*", searchOption))
{
- var checkpointFile = new CheckpointFile()
+ if (!SupportedCheckpointExtensions.Any(ext => file.Contains(ext)))
+ continue;
+
+ var checkpointFile = new CheckpointFile
{
Title = Path.GetFileNameWithoutExtension(file),
FilePath = Path.Combine(directory, file),
};
- // Check for connected model info
- var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(file);
- var cmInfoPath = $"{fileNameWithoutExtension}.cm-info.json";
- if (files.TryGetValue(cmInfoPath, out var jsonPath))
+ var jsonPath = Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(file)}.cm-info.json");
+ if (File.Exists(jsonPath))
{
- try
- {
- var jsonData = File.ReadAllText(jsonPath);
- checkpointFile.ConnectedModel = ConnectedModelInfo.FromJson(jsonData);
- }
- catch (IOException e)
- {
- Debug.WriteLine($"Failed to parse {cmInfoPath}: {e}");
- }
+ var json = File.ReadAllText(jsonPath);
+ var connectedModelInfo = ConnectedModelInfo.FromJson(json);
+ checkpointFile.ConnectedModel = connectedModelInfo;
}
- // Check for preview image
- var previewImage = SupportedImageExtensions.Select(ext => $"{fileNameWithoutExtension}.preview{ext}").FirstOrDefault(files.ContainsKey);
- if (previewImage != null)
+ checkpointFile.PreviewImagePath = SupportedImageExtensions
+ .Select(ext => Path.Combine(directory,
+ $"{Path.GetFileNameWithoutExtension(file)}.preview{ext}")).Where(File.Exists)
+ .FirstOrDefault();
+
+ yield return checkpointFile;
+ }
+ }
+
+ public static IEnumerable GetAllCheckpointFiles(string modelsDirectory)
+ {
+ foreach (var file in Directory.EnumerateFiles(modelsDirectory, "*.*", SearchOption.AllDirectories))
+ {
+ if (!SupportedCheckpointExtensions.Any(ext => file.Contains(ext)))
+ continue;
+
+ var checkpointFile = new CheckpointFile
+ {
+ Title = Path.GetFileNameWithoutExtension(file),
+ FilePath = file,
+ };
+
+ var jsonPath = Path.Combine(Path.GetDirectoryName(file),
+ Path.GetFileNameWithoutExtension(file) + ".cm-info.json");
+
+ if (File.Exists(jsonPath))
{
- checkpointFile.PreviewImagePath = files[previewImage];
+ var json = File.ReadAllText(jsonPath);
+ var connectedModelInfo = ConnectedModelInfo.FromJson(json);
+ checkpointFile.ConnectedModel = connectedModelInfo;
+ checkpointFile.ModelType = GetCivitModelType(file);
}
+ checkpointFile.PreviewImagePath = SupportedImageExtensions
+ .Select(ext => Path.Combine(Path.GetDirectoryName(file),
+ $"{Path.GetFileNameWithoutExtension(file)}.preview{ext}")).Where(File.Exists)
+ .FirstOrDefault();
+
yield return checkpointFile;
}
}
@@ -253,4 +273,39 @@ public partial class CheckpointFile : ViewModelBase
yield return checkpointFile;
}
}
+
+ private static CivitModelType GetCivitModelType(string filePath)
+ {
+ if (filePath.Contains(SharedFolderType.StableDiffusion.ToString()))
+ {
+ return CivitModelType.Checkpoint;
+ }
+
+ if (filePath.Contains(SharedFolderType.ControlNet.ToString()))
+ {
+ return CivitModelType.Controlnet;
+ }
+
+ if (filePath.Contains(SharedFolderType.Lora.ToString()))
+ {
+ return CivitModelType.LORA;
+ }
+
+ if (filePath.Contains(SharedFolderType.TextualInversion.ToString()))
+ {
+ return CivitModelType.TextualInversion;
+ }
+
+ if (filePath.Contains(SharedFolderType.Hypernetwork.ToString()))
+ {
+ return CivitModelType.Hypernetwork;
+ }
+
+ if (filePath.Contains(SharedFolderType.LyCORIS.ToString()))
+ {
+ return CivitModelType.LoCon;
+ }
+
+ return CivitModelType.Unknown;
+ }
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
index 2808d564..eaf8eadb 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFolder.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
+using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
@@ -414,8 +415,7 @@ public partial class CheckpointFolder : ViewModelBase
{
// Create subfolder
var subFolder = new CheckpointFolder(settingsManager,
- downloadService, modelFinder,
- useCategoryVisibility: false)
+ downloadService, modelFinder, useCategoryVisibility: false)
{
Title = Path.GetFileName(folder),
DirectoryPath = folder,
diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
index a5a59945..b3351ebc 100644
--- a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.ObjectModel;
+using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
@@ -10,6 +11,7 @@ using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
@@ -68,45 +70,59 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
this.downloadService = downloadService;
this.modelFinder = modelFinder;
}
-
+
public override async Task OnLoadedAsync()
{
+ var sw = Stopwatch.StartNew();
DisplayedCheckpointFolders = CheckpointFolders;
// Set UI states
IsImportAsConnected = settingsManager.Settings.IsImportAsConnected;
// Refresh search filter
OnSearchFilterChanged(string.Empty);
+
+ Logger.Info($"Loaded {DisplayedCheckpointFolders.Count} checkpoint folders in {sw.ElapsedMilliseconds}ms");
if (Design.IsDesignMode) return;
- await Dispatcher.UIThread.InvokeAsync(async () =>
- {
- IsLoading = CheckpointFolders.Count == 0;
- IsIndexing = CheckpointFolders.Count > 0;
- await IndexFolders();
- IsLoading = false;
- IsIndexing = false;
- });
+ IsLoading = CheckpointFolders.Count == 0;
+ IsIndexing = CheckpointFolders.Count > 0;
+ await IndexFolders();
+ IsLoading = false;
+ IsIndexing = false;
+
+ Logger.Info($"OnLoadedAsync in {sw.ElapsedMilliseconds}ms");
}
// ReSharper disable once UnusedParameterInPartialMethod
partial void OnSearchFilterChanged(string value)
{
+ var sw = Stopwatch.StartNew();
if (string.IsNullOrWhiteSpace(SearchFilter))
{
- DisplayedCheckpointFolders = CheckpointFolders;
+ DisplayedCheckpointFolders = new ObservableCollection(
+ CheckpointFolders.Select(x =>
+ {
+ x.SearchFilter = SearchFilter;
+ return x;
+ }));
+ sw.Stop();
+ Logger.Info($"OnSearchFilterChanged in {sw.ElapsedMilliseconds}ms");
return;
}
+ sw.Restart();
+
var filteredFolders = CheckpointFolders
.Where(ContainsSearchFilter).ToList();
foreach (var folder in filteredFolders)
{
folder.SearchFilter = SearchFilter;
}
+ sw.Stop();
+ Logger.Info($"ContainsSearchFilter in {sw.ElapsedMilliseconds}ms");
- DisplayedCheckpointFolders = new ObservableCollection(filteredFolders);
+ DisplayedCheckpointFolders = new ObservableCollection(filteredFolders);
}
private bool ContainsSearchFilter(CheckpointManager.CheckpointFolder folder)
@@ -136,11 +152,13 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
var folders = Directory.GetDirectories(modelsDirectory);
+ var sw = Stopwatch.StartNew();
+
// Index all folders
- var indexTasks = folders.Select(f => Task.Run(async () =>
+ var indexTasks = folders.Select(async f =>
{
var checkpointFolder =
- new CheckpointManager.CheckpointFolder(settingsManager, downloadService, modelFinder)
+ new CheckpointFolder(settingsManager, downloadService, modelFinder)
{
Title = Path.GetFileName(f),
DirectoryPath = f,
@@ -148,21 +166,29 @@ public partial class CheckpointsPageViewModel : PageViewModelBase
};
await checkpointFolder.IndexAsync();
return checkpointFolder;
- })).ToList();
+ }).ToList();
await Task.WhenAll(indexTasks);
+
+ sw.Stop();
+ Logger.Info($"IndexFolders in {sw.ElapsedMilliseconds}ms");
// Set new observable collection, ordered by alphabetical order
CheckpointFolders =
- new ObservableCollection(indexTasks
+ new ObservableCollection(indexTasks
.Select(t => t.Result)
.OrderBy(f => f.Title));
-
+
if (!string.IsNullOrWhiteSpace(SearchFilter))
{
- DisplayedCheckpointFolders = new ObservableCollection(
- CheckpointFolders
- .Where(x => x.CheckpointFiles.Any(y => y.FileName.Contains(SearchFilter))));
+ var filtered = CheckpointFolders
+ .Where(x => x.CheckpointFiles.Any(y => y.FileName.Contains(SearchFilter))).Select(
+ f =>
+ {
+ f.SearchFilter = SearchFilter;
+ return f;
+ });
+ DisplayedCheckpointFolders = new ObservableCollection(filtered);
}
else
{
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
index 209db2a1..d9f99a7b 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
@@ -199,14 +199,14 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
version = SelectedVersion?.TagName ??
throw new NullReferenceException("Selected version is null");
- await DownloadPackage(version, false);
+ await DownloadPackage(version, false, null);
}
else
{
version = SelectedCommit?.Sha ??
throw new NullReferenceException("Selected commit is null");
- await DownloadPackage(version, true);
+ await DownloadPackage(version, true, SelectedVersion!.TagName);
}
await InstallPackage();
@@ -271,7 +271,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
return branch == null ? version : $"{branch}@{version[..7]}";
}
- private Task DownloadPackage(string version, bool isCommitHash)
+ private Task DownloadPackage(string version, bool isCommitHash, string? branch)
{
InstallProgress.Text = "Downloading package...";
@@ -282,7 +282,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
EventManager.Instance.OnGlobalProgressChanged((int) progress.Percentage);
});
- return SelectedPackage.DownloadPackage(version, isCommitHash, progress);
+ return SelectedPackage.DownloadPackage(version, isCommitHash, branch, progress);
}
private async Task InstallPackage()
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs
index 519fe8c1..9979fd6f 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LaunchOptionsViewModel.cs
@@ -1,9 +1,13 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
+using System.ComponentModel;
using System.Linq;
+using System.Reactive.Linq;
+using System.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using FuzzySharp;
+using Microsoft.Extensions.Logging;
using StabilityMatrix.Avalonia.Views.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper.Cache;
@@ -14,26 +18,28 @@ namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
[View(typeof(LaunchOptionsDialog))]
public partial class LaunchOptionsViewModel : ContentDialogViewModelBase
{
+ private readonly ILogger logger;
private readonly LRUCache> cache = new(100);
-
- [ObservableProperty] private string title = "Launch Options";
- [ObservableProperty] private bool isSearchBoxEnabled = true;
-
+
+ [ObservableProperty]
+ private string title = "Launch Options";
+
+ [ObservableProperty]
+ private bool isSearchBoxEnabled = true;
+
[ObservableProperty]
- [NotifyPropertyChangedFor(nameof(FilteredCards))]
private string searchText = string.Empty;
-
- [ObservableProperty]
+
+ [ObservableProperty]
private IReadOnlyList? filteredCards;
-
+
public IReadOnlyList? Cards { get; set; }
-
+
///
/// Return cards that match the search text
///
- private IReadOnlyList? GetFilteredCards()
+ private IReadOnlyList? GetFilteredCards(string? text)
{
- var text = SearchText;
if (string.IsNullOrWhiteSpace(text) || text.Length < 2)
{
return Cards;
@@ -50,18 +56,30 @@ public partial class LaunchOptionsViewModel : ContentDialogViewModelBase
Type = LaunchOptionType.Bool,
Options = Array.Empty()
};
-
- var extracted = Process
- .ExtractTop(searchCard, Cards, c => c.Title.ToLowerInvariant());
- var results = extracted
- .Where(r => r.Score > 40)
- .Select(r => r.Value)
- .ToImmutableList();
+
+ var extracted = Process.ExtractTop(searchCard, Cards, c => c.Title.ToLowerInvariant());
+ var results = extracted.Where(r => r.Score > 40).Select(r => r.Value).ToImmutableList();
cache.Add(text, results);
return results;
}
- public void UpdateFilterCards() => FilteredCards = GetFilteredCards();
+ public void UpdateFilterCards() => FilteredCards = GetFilteredCards(SearchText);
+
+ public LaunchOptionsViewModel(ILogger logger)
+ {
+ this.logger = logger;
+
+ Observable
+ .FromEventPattern(this, nameof(PropertyChanged))
+ .Where(x => x.EventArgs.PropertyName == nameof(SearchText))
+ .Throttle(TimeSpan.FromMilliseconds(50))
+ .Select(_ => SearchText)
+ .ObserveOn(SynchronizationContext.Current!)
+ .Subscribe(
+ text => FilteredCards = GetFilteredCards(text),
+ err => logger.LogError(err, "Error while filtering launch options")
+ );
+ }
public override void OnLoaded()
{
@@ -75,8 +93,9 @@ public partial class LaunchOptionsViewModel : ContentDialogViewModelBase
public List AsLaunchArgs()
{
var launchArgs = new List();
- if (Cards is null) return launchArgs;
-
+ if (Cards is null)
+ return launchArgs;
+
foreach (var card in Cards)
{
launchArgs.AddRange(card.Options);
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
index 854797e4..e2ff39cb 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.ObjectModel;
using System.IO;
+using System.Linq;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
@@ -55,7 +56,8 @@ public partial class OneClickInstallViewModel : ViewModelBase
SubHeaderText = "Choose your preferred interface and click Install to get started!";
ShowInstallButton = true;
AllPackages =
- new ObservableCollection(this.packageFactory.GetAllAvailablePackages());
+ new ObservableCollection(this.packageFactory.GetAllAvailablePackages()
+ .Where(p => p.OfferInOneClickInstaller));
SelectedPackage = AllPackages[0];
}
@@ -157,7 +159,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress);
});
- await SelectedPackage.DownloadPackage(version, false, progress);
+ await SelectedPackage.DownloadPackage(version, false, version, progress);
SubHeaderText = "Download Complete";
OneClickInstallProgress = 100;
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs
new file mode 100644
index 00000000..40ed6a88
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs
@@ -0,0 +1,221 @@
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Collections.ObjectModel;
+using System.IO;
+using System.Linq;
+using System.Threading.Tasks;
+using AsyncAwaitBestPractices;
+using Avalonia.Controls;
+using Avalonia.Threading;
+using CommunityToolkit.Mvvm.ComponentModel;
+using NLog;
+using StabilityMatrix.Avalonia.Views.Dialogs;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Helper.Factory;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.Database;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Packages;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
+
+[View(typeof(PackageImportDialog))]
+public partial class PackageImportViewModel : ContentDialogViewModelBase
+{
+ private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
+
+ private readonly IPackageFactory packageFactory;
+ private readonly ISettingsManager settingsManager;
+
+ [ObservableProperty] private DirectoryPath? packagePath;
+ [ObservableProperty] private BasePackage? selectedBasePackage;
+
+ public IReadOnlyList AvailablePackages
+ => packageFactory.GetAllAvailablePackages().ToImmutableArray();
+
+ [ObservableProperty] private PackageVersion? selectedVersion;
+
+ [ObservableProperty] private ObservableCollection? availableCommits;
+ [ObservableProperty] private ObservableCollection? availableVersions;
+
+ [ObservableProperty] private GitCommit? selectedCommit;
+
+ // Version types (release or commit)
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(ReleaseLabelText),
+ nameof(IsReleaseMode), nameof(SelectedVersion))]
+ private PackageVersionType selectedVersionType = PackageVersionType.Commit;
+
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(IsReleaseModeAvailable))]
+ private PackageVersionType availableVersionTypes =
+ PackageVersionType.GithubRelease | PackageVersionType.Commit;
+ public string ReleaseLabelText => IsReleaseMode ? "Version" : "Branch";
+ public bool IsReleaseMode
+ {
+ get => SelectedVersionType == PackageVersionType.GithubRelease;
+ set => SelectedVersionType = value ? PackageVersionType.GithubRelease : PackageVersionType.Commit;
+ }
+
+ public bool IsReleaseModeAvailable => AvailableVersionTypes.HasFlag(PackageVersionType.GithubRelease);
+
+ public PackageImportViewModel(
+ IPackageFactory packageFactory,
+ ISettingsManager settingsManager)
+ {
+ this.packageFactory = packageFactory;
+ this.settingsManager = settingsManager;
+ }
+
+ public override async Task OnLoadedAsync()
+ {
+ SelectedBasePackage ??= AvailablePackages[0];
+
+ if (Design.IsDesignMode) return;
+ // Populate available versions
+ try
+ {
+ if (IsReleaseMode)
+ {
+ var versions = (await SelectedBasePackage.GetAllVersions()).ToList();
+ AvailableVersions = new ObservableCollection(versions);
+ if (!AvailableVersions.Any()) return;
+
+ SelectedVersion = AvailableVersions[0];
+ }
+ else
+ {
+ var branches = (await SelectedBasePackage.GetAllBranches()).ToList();
+ AvailableVersions = new ObservableCollection(branches.Select(b =>
+ new PackageVersion
+ {
+ TagName = b.Name,
+ ReleaseNotesMarkdown = b.Commit.Label
+ }));
+ UpdateSelectedVersionToLatestMain();
+ }
+ }
+ catch (Exception e)
+ {
+ Logger.Warn("Error getting versions: {Exception}", e.ToString());
+ }
+ }
+
+ private static string GetDisplayVersion(string version, string? branch)
+ {
+ return branch == null ? version : $"{branch}@{version[..7]}";
+ }
+
+ // When available version types change, reset selected version type if not compatible
+ partial void OnAvailableVersionTypesChanged(PackageVersionType value)
+ {
+ if (!value.HasFlag(SelectedVersionType))
+ {
+ SelectedVersionType = value;
+ }
+ }
+
+ // When changing branch / release modes, refresh
+ // ReSharper disable once UnusedParameterInPartialMethod
+ partial void OnSelectedVersionTypeChanged(PackageVersionType value)
+ => OnSelectedBasePackageChanged(SelectedBasePackage);
+
+ partial void OnSelectedBasePackageChanged(BasePackage? value)
+ {
+ if (value is null || SelectedBasePackage is null)
+ {
+ AvailableVersions?.Clear();
+ AvailableCommits?.Clear();
+ return;
+ }
+
+ AvailableVersions?.Clear();
+ AvailableCommits?.Clear();
+
+ AvailableVersionTypes = SelectedBasePackage.ShouldIgnoreReleases
+ ? PackageVersionType.Commit
+ : PackageVersionType.GithubRelease | PackageVersionType.Commit;
+
+ if (Design.IsDesignMode) return;
+
+ Dispatcher.UIThread.InvokeAsync(async () =>
+ {
+ Logger.Debug($"Release mode: {IsReleaseMode}");
+ var versions = (await value.GetAllVersions(IsReleaseMode)).ToList();
+
+ if (!versions.Any()) return;
+
+ AvailableVersions = new ObservableCollection(versions);
+ Logger.Debug($"Available versions: {string.Join(", ", AvailableVersions)}");
+ SelectedVersion = AvailableVersions[0];
+
+ if (!IsReleaseMode)
+ {
+ var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList();
+ if (commits is null || commits.Count == 0) return;
+
+ AvailableCommits = new ObservableCollection(commits);
+ SelectedCommit = AvailableCommits[0];
+ UpdateSelectedVersionToLatestMain();
+ }
+ }).SafeFireAndForget();
+ }
+
+ private void UpdateSelectedVersionToLatestMain()
+ {
+ if (AvailableVersions is null)
+ {
+ SelectedVersion = null;
+ }
+ else
+ {
+ // First try to find master
+ var version = AvailableVersions.FirstOrDefault(x => x.TagName == "master");
+ // If not found, try main
+ version ??= AvailableVersions.FirstOrDefault(x => x.TagName == "main");
+
+ // If still not found, just use the first one
+ version ??= AvailableVersions[0];
+
+ SelectedVersion = version;
+ }
+ }
+
+ public void AddPackageWithCurrentInputs()
+ {
+ if (SelectedBasePackage is null || PackagePath is null)
+ return;
+
+ string version;
+ if (IsReleaseMode)
+ {
+ version = SelectedVersion?.TagName ??
+ throw new NullReferenceException("Selected version is null");
+ }
+ else
+ {
+ version = SelectedCommit?.Sha ??
+ throw new NullReferenceException("Selected commit is null");
+ }
+
+ var branch = SelectedVersionType == PackageVersionType.GithubRelease ?
+ null : SelectedVersion!.TagName;
+
+ var package = new InstalledPackage
+ {
+ Id = Guid.NewGuid(),
+ DisplayName = PackagePath.Name,
+ PackageName = SelectedBasePackage.Name,
+ LibraryPath = $"Packages{Path.DirectorySeparatorChar}{PackagePath.Name}",
+ PackageVersion = version,
+ DisplayVersion = GetDisplayVersion(version, branch),
+ InstalledBranch = branch,
+ LaunchCommand = SelectedBasePackage.LaunchCommand,
+ LastUpdateCheck = DateTimeOffset.Now,
+ };
+
+ settingsManager.Transaction(s => s.InstalledPackages.Add(package));
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
index d6435c85..082859db 100644
--- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
@@ -42,8 +42,11 @@ public partial class SelectModelVersionViewModel : ContentDialogViewModelBase
var firstImageUrl = value?.ModelVersion?.Images?.FirstOrDefault(
img => nsfwEnabled || img.Nsfw == "None")?.Url;
- Dispatcher.UIThread.InvokeAsync(async
- () => await UpdateImage(firstImageUrl));
+ Dispatcher.UIThread.InvokeAsync(async () =>
+ {
+ SelectedFile = value?.CivitFileViewModels.FirstOrDefault();
+ await UpdateImage(firstImageUrl);
+ });
}
partial void OnSelectedFileChanged(CivitFileViewModel? value)
diff --git a/StabilityMatrix.Avalonia/ViewModels/DownloadProgressItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/DownloadProgressItemViewModel.cs
new file mode 100644
index 00000000..7e5fee85
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/DownloadProgressItemViewModel.cs
@@ -0,0 +1,86 @@
+using System;
+using System.Threading.Tasks;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.Progress;
+
+namespace StabilityMatrix.Avalonia.ViewModels;
+
+public class DownloadProgressItemViewModel : PausableProgressItemViewModelBase
+{
+ private readonly TrackedDownload download;
+
+ public DownloadProgressItemViewModel(TrackedDownload download)
+ {
+ this.download = download;
+
+ Id = download.Id;
+ Name = download.FileName;
+ State = download.ProgressState;
+ OnProgressStateChanged(State);
+
+ // If initial progress provided, load it
+ if (download is {TotalBytes: > 0, DownloadedBytes: > 0})
+ {
+ var current = download.DownloadedBytes / (double) download.TotalBytes;
+ Progress.Value = (float) Math.Ceiling(Math.Clamp(current, 0, 1) * 100);
+ }
+
+ download.ProgressUpdate += (s, e) =>
+ {
+ Progress.Value = e.Percentage;
+ Progress.IsIndeterminate = e.IsIndeterminate;
+ };
+
+ download.ProgressStateChanged += (s, e) =>
+ {
+ State = e;
+ OnProgressStateChanged(e);
+ };
+ }
+
+ private void OnProgressStateChanged(ProgressState state)
+ {
+ if (state == ProgressState.Inactive)
+ {
+ Progress.Text = "Paused";
+ }
+ else if (state == ProgressState.Working)
+ {
+ Progress.Text = "Downloading...";
+ }
+ else if (state == ProgressState.Success)
+ {
+ Progress.Text = "Completed";
+ }
+ else if (state == ProgressState.Cancelled)
+ {
+ Progress.Text = "Cancelled";
+ }
+ else if (state == ProgressState.Failed)
+ {
+ Progress.Text = "Failed";
+ }
+ }
+
+ ///
+ public override Task Cancel()
+ {
+ download.Cancel();
+ return Task.CompletedTask;
+ }
+
+ ///
+ public override Task Pause()
+ {
+ download.Pause();
+ return Task.CompletedTask;
+ }
+
+ ///
+ public override Task Resume()
+ {
+ download.Resume();
+ return Task.CompletedTask;
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
index 355ed12c..1b8d3131 100644
--- a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
@@ -33,7 +33,6 @@ using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
-using Notification = Avalonia.Controls.Notifications.Notification;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
@@ -305,13 +304,16 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
IsPrimaryButtonEnabled = true,
PrimaryButtonText = "Save",
CloseButtonText = "Cancel",
+ FullSizeDesired = true,
DefaultButton = ContentDialogButton.Primary,
+ ContentMargin = new Thickness(32, 16),
Padding = new Thickness(0, 16),
Content = new LaunchOptionsDialog
{
DataContext = viewModel,
}
};
+
var result = await dialog.ShowAsync();
if (result == ContentDialogResult.Primary)
@@ -448,7 +450,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
private void RunningPackageOnStartupComplete(object? sender, string e)
{
- webUiUrl = e;
+ webUiUrl = e.Replace("0.0.0.0", "127.0.0.1");
ShowWebUiButton = !string.IsNullOrWhiteSpace(webUiUrl);
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs
index 7ab02651..66bb1a6d 100644
--- a/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs
@@ -24,6 +24,7 @@ public partial class MainWindowViewModel : ViewModelBase
{
private readonly ISettingsManager settingsManager;
private readonly ServiceManager dialogFactory;
+ private readonly ITrackedDownloadService trackedDownloadService;
private readonly IDiscordRichPresenceService discordRichPresenceService;
public string Greeting => "Welcome to Avalonia!";
@@ -45,11 +46,13 @@ public partial class MainWindowViewModel : ViewModelBase
public MainWindowViewModel(
ISettingsManager settingsManager,
IDiscordRichPresenceService discordRichPresenceService,
- ServiceManager dialogFactory)
+ ServiceManager dialogFactory,
+ ITrackedDownloadService trackedDownloadService)
{
this.settingsManager = settingsManager;
this.dialogFactory = dialogFactory;
this.discordRichPresenceService = discordRichPresenceService;
+ this.trackedDownloadService = trackedDownloadService;
ProgressManagerViewModel = dialogFactory.Get();
UpdateViewModel = dialogFactory.Get();
@@ -81,6 +84,9 @@ public partial class MainWindowViewModel : ViewModelBase
// Initialize Discord Rich Presence (this needs LibraryDir so is set here)
discordRichPresenceService.UpdateState();
+ // Load in-progress downloads
+ ProgressManagerViewModel.AddDownloads(trackedDownloadService.Downloads);
+
// Index checkpoints if we dont have
Task.Run(() => settingsManager.IndexCheckpoints()).SafeFireAndForget();
diff --git a/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs
new file mode 100644
index 00000000..b165aef0
--- /dev/null
+++ b/StabilityMatrix.Avalonia/ViewModels/NewCheckpointsPageViewModel.cs
@@ -0,0 +1,221 @@
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Collections.ObjectModel;
+using System.IO;
+using System.Linq;
+using System.Net.Http;
+using System.Threading.Tasks;
+using AsyncAwaitBestPractices;
+using Avalonia.Controls;
+using Avalonia.Controls.Notifications;
+using AvaloniaEdit.Utils;
+using CommunityToolkit.Mvvm.ComponentModel;
+using FluentAvalonia.UI.Controls;
+using Microsoft.Extensions.Logging;
+using Refit;
+using StabilityMatrix.Avalonia.Controls;
+using StabilityMatrix.Avalonia.Services;
+using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Avalonia.ViewModels.CheckpointManager;
+using StabilityMatrix.Avalonia.ViewModels.Dialogs;
+using StabilityMatrix.Avalonia.Views;
+using StabilityMatrix.Avalonia.Views.Dialogs;
+using StabilityMatrix.Core.Api;
+using StabilityMatrix.Core.Attributes;
+using StabilityMatrix.Core.Database;
+using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.Api;
+using StabilityMatrix.Core.Services;
+using Symbol = FluentIcons.Common.Symbol;
+using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
+
+namespace StabilityMatrix.Avalonia.ViewModels;
+
+[View(typeof(NewCheckpointsPage))]
+public partial class NewCheckpointsPageViewModel : PageViewModelBase
+{
+ private readonly ILogger logger;
+ private readonly ISettingsManager settingsManager;
+ private readonly ILiteDbContext liteDbContext;
+ private readonly ICivitApi civitApi;
+ private readonly ServiceManager dialogFactory;
+ private readonly INotificationService notificationService;
+ public override string Title => "Checkpoint Manager";
+ public override IconSource IconSource => new SymbolIconSource
+ {Symbol = Symbol.Cellular5g, IsFilled = true};
+
+ public NewCheckpointsPageViewModel(ILogger logger,
+ ISettingsManager settingsManager, ILiteDbContext liteDbContext, ICivitApi civitApi,
+ ServiceManager dialogFactory, INotificationService notificationService)
+ {
+ this.logger = logger;
+ this.settingsManager = settingsManager;
+ this.liteDbContext = liteDbContext;
+ this.civitApi = civitApi;
+ this.dialogFactory = dialogFactory;
+ this.notificationService = notificationService;
+ }
+
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(ConnectedCheckpoints))]
+ [NotifyPropertyChangedFor(nameof(NonConnectedCheckpoints))]
+ private ObservableCollection allCheckpoints = new();
+
+ [ObservableProperty]
+ private ObservableCollection civitModels = new();
+
+ public ObservableCollection ConnectedCheckpoints => new(
+ AllCheckpoints.Where(x => x.IsConnectedModel)
+ .OrderBy(x => x.ConnectedModel!.ModelName)
+ .ThenBy(x => x.ModelType)
+ .GroupBy(x => x.ConnectedModel!.ModelId)
+ .Select(x => x.First()));
+
+ public ObservableCollection NonConnectedCheckpoints => new(
+ AllCheckpoints.Where(x => !x.IsConnectedModel).OrderBy(x => x.ModelType));
+
+ public override async Task OnLoadedAsync()
+ {
+ if (Design.IsDesignMode) return;
+
+ var files = CheckpointFile.GetAllCheckpointFiles(settingsManager.ModelsDirectory);
+ AllCheckpoints = new ObservableCollection(files);
+
+ var connectedModelIds = ConnectedCheckpoints.Select(x => x.ConnectedModel.ModelId);
+ var modelRequest = new CivitModelsRequest
+ {
+ CommaSeparatedModelIds = string.Join(',', connectedModelIds)
+ };
+
+ // See if query is cached
+ var cachedQuery = await liteDbContext.CivitModelQueryCache
+ .IncludeAll()
+ .FindByIdAsync(ObjectHash.GetMd5Guid(modelRequest));
+
+ // If cached, update model cards
+ if (cachedQuery is not null)
+ {
+ CivitModels = new ObservableCollection(cachedQuery.Items);
+
+ // Start remote query (background mode)
+ // Skip when last query was less than 2 min ago
+ var timeSinceCache = DateTimeOffset.UtcNow - cachedQuery.InsertedAt;
+ if (timeSinceCache?.TotalMinutes >= 2)
+ {
+ CivitQuery(modelRequest).SafeFireAndForget();
+ }
+ }
+ else
+ {
+ await CivitQuery(modelRequest);
+ }
+ }
+
+ public async Task ShowVersionDialog(int modelId)
+ {
+ var model = CivitModels.FirstOrDefault(m => m.Id == modelId);
+ if (model == null)
+ {
+ notificationService.Show(new Notification("Model has no versions available",
+ "This model has no versions available for download", NotificationType.Warning));
+ return;
+ }
+ var versions = model.ModelVersions;
+ if (versions is null || versions.Count == 0)
+ {
+ notificationService.Show(new Notification("Model has no versions available",
+ "This model has no versions available for download", NotificationType.Warning));
+ return;
+ }
+
+ var dialog = new BetterContentDialog
+ {
+ Title = model.Name,
+ IsPrimaryButtonEnabled = false,
+ IsSecondaryButtonEnabled = false,
+ IsFooterVisible = false,
+ MaxDialogWidth = 750,
+ };
+
+ var viewModel = dialogFactory.Get();
+ viewModel.Dialog = dialog;
+ viewModel.Versions = versions.Select(version =>
+ new ModelVersionViewModel(
+ settingsManager.Settings.InstalledModelHashes ?? new HashSet(), version))
+ .ToImmutableArray();
+ viewModel.SelectedVersionViewModel = viewModel.Versions[0];
+
+ dialog.Content = new SelectModelVersionDialog
+ {
+ DataContext = viewModel
+ };
+
+ var result = await dialog.ShowAsync();
+
+ if (result != ContentDialogResult.Primary)
+ {
+ return;
+ }
+
+ var selectedVersion = viewModel?.SelectedVersionViewModel?.ModelVersion;
+ var selectedFile = viewModel?.SelectedFile?.CivitFile;
+ }
+
+ private async Task CivitQuery(CivitModelsRequest request)
+ {
+ try
+ {
+ var modelResponse = await civitApi.GetModels(request);
+ var models = modelResponse.Items;
+ // Filter out unknown model types and archived/taken-down models
+ models = models.Where(m => m.Type.ConvertTo() > 0)
+ .Where(m => m.Mode == null).ToList();
+
+ // Database update calls will invoke `OnModelsUpdated`
+ // Add to database
+ await liteDbContext.UpsertCivitModelAsync(models);
+ // Add as cache entry
+ var cacheNew = await liteDbContext.UpsertCivitModelQueryCacheEntryAsync(
+ new CivitModelQueryCacheEntry
+ {
+ Id = ObjectHash.GetMd5Guid(request),
+ InsertedAt = DateTimeOffset.UtcNow,
+ Request = request,
+ Items = models,
+ Metadata = modelResponse.Metadata
+ });
+
+ if (cacheNew)
+ {
+ CivitModels = new ObservableCollection(models);
+ }
+ }
+ catch (OperationCanceledException)
+ {
+ notificationService.Show(new Notification("Request to CivitAI timed out",
+ "Could not check for checkpoint updates. Please try again later."));
+ logger.LogWarning($"CivitAI query timed out ({request})");
+ }
+ catch (HttpRequestException e)
+ {
+ notificationService.Show(new Notification("CivitAI can't be reached right now",
+ "Could not check for checkpoint updates. Please try again later."));
+ logger.LogWarning(e, $"CivitAI query HttpRequestException ({request})");
+ }
+ catch (ApiException e)
+ {
+ notificationService.Show(new Notification("CivitAI can't be reached right now",
+ "Could not check for checkpoint updates. Please try again later."));
+ logger.LogWarning(e, $"CivitAI query ApiException ({request})");
+ }
+ catch (Exception e)
+ {
+ notificationService.Show(new Notification("CivitAI can't be reached right now",
+ $"Unknown exception during CivitAI query: {e.GetType().Name}"));
+ logger.LogError(e, $"CivitAI query unknown exception ({request})");
+ }
+ }
+}
diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
index 5c154c99..d8b538e9 100644
--- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
@@ -1,20 +1,22 @@
using System;
-using System.IO;
-using System.Linq;
using System.Threading.Tasks;
+using Avalonia.Controls;
using Avalonia.Controls.Notifications;
using CommunityToolkit.Mvvm.ComponentModel;
-using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
-using NLog;
-using Polly;
+using Microsoft.Extensions.Logging;
using StabilityMatrix.Avalonia.Animations;
+using StabilityMatrix.Avalonia.Extensions;
+using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
+using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory;
using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;
@@ -23,27 +25,33 @@ namespace StabilityMatrix.Avalonia.ViewModels.PackageManager;
public partial class PackageCardViewModel : ProgressViewModel
{
+ private readonly ILogger logger;
private readonly IPackageFactory packageFactory;
private readonly INotificationService notificationService;
private readonly ISettingsManager settingsManager;
private readonly INavigationService navigationService;
- private readonly Logger logger = LogManager.GetCurrentClassLogger();
+ private readonly ServiceManager vmFactory;
[ObservableProperty] private InstalledPackage? package;
- [ObservableProperty] private Uri cardImage;
+ [ObservableProperty] private string? cardImageSource;
[ObservableProperty] private bool isUpdateAvailable;
- [ObservableProperty] private string installedVersion;
-
+ [ObservableProperty] private string? installedVersion;
+ [ObservableProperty] private bool isUnknownPackage;
+
public PackageCardViewModel(
+ ILogger logger,
IPackageFactory packageFactory,
INotificationService notificationService,
ISettingsManager settingsManager,
- INavigationService navigationService)
+ INavigationService navigationService,
+ ServiceManager vmFactory)
{
+ this.logger = logger;
this.packageFactory = packageFactory;
this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.navigationService = navigationService;
+ this.vmFactory = vmFactory;
}
partial void OnPackageChanged(InstalledPackage? value)
@@ -51,9 +59,21 @@ public partial class PackageCardViewModel : ProgressViewModel
if (string.IsNullOrWhiteSpace(value?.PackageName))
return;
- var basePackage = packageFactory[value.PackageName];
- CardImage = basePackage?.PreviewImageUri ?? Assets.NoImage;
- InstalledVersion = value.DisplayVersion ?? "Unknown";
+ if (value.PackageName == UnknownPackage.Key)
+ {
+ IsUnknownPackage = true;
+ CardImageSource = "";
+ InstalledVersion = "Unknown";
+ }
+ else
+ {
+ IsUnknownPackage = false;
+
+ var basePackage = packageFactory[value.PackageName];
+ CardImageSource = basePackage?.PreviewImageUri.ToString()
+ ?? Assets.NoImage.ToString();
+ InstalledVersion = value.DisplayVersion ?? "Unknown";
+ }
}
public override async Task OnLoadedAsync()
@@ -94,9 +114,10 @@ public partial class PackageCardViewModel : ProgressViewModel
Text = "Uninstalling...";
IsIndeterminate = true;
Value = -1;
-
- var deleteTask = DeleteDirectoryAsync(Path.Combine(settingsManager.LibraryDir,
- Package.LibraryPath));
+
+ var packagePath = new DirectoryPath(settingsManager.LibraryDir, Package.LibraryPath);
+ var deleteTask = packagePath.DeleteVerboseAsync(logger);
+
var taskResult = await notificationService.TryAsync(deleteTask,
"Some files could not be deleted. Please close any open files in the package directory and try again.");
if (taskResult.IsSuccessful)
@@ -104,11 +125,14 @@ public partial class PackageCardViewModel : ProgressViewModel
notificationService.Show(new Notification("Success",
$"Package {Package.DisplayName} uninstalled",
NotificationType.Success));
-
- settingsManager.Transaction(settings =>
+
+ if (!IsUnknownPackage)
{
- settings.RemoveInstalledPackageAndUpdateActive(Package);
- });
+ settingsManager.Transaction(settings =>
+ {
+ settings.RemoveInstalledPackageAndUpdateActive(Package);
+ });
+ }
EventManager.Instance.OnInstalledPackagesChanged();
}
@@ -117,25 +141,27 @@ public partial class PackageCardViewModel : ProgressViewModel
public async Task Update()
{
- if (Package == null) return;
+ if (Package is null || IsUnknownPackage) return;
var basePackage = packageFactory[Package.PackageName!];
if (basePackage == null)
{
- logger.Warn("Could not find package {SelectedPackagePackageName}",
+ logger.LogWarning("Could not find package {SelectedPackagePackageName}",
Package.PackageName);
notificationService.Show("Invalid Package type",
$"Package {Package.PackageName.ToRepr()} is not a valid package type",
NotificationType.Error);
return;
}
+
+ var packageName = Package.DisplayName ?? Package.PackageName ?? "";
- Text = $"Updating {Package.DisplayName}";
+ Text = $"Updating {packageName}";
IsIndeterminate = true;
var progressId = Guid.NewGuid();
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- Package.DisplayName,
+ Package.DisplayName ?? Package.PackageName!,
new ProgressReport(0f, isIndeterminate: true, type: ProgressType.Update)));
try
@@ -152,7 +178,7 @@ public partial class PackageCardViewModel : ProgressViewModel
EventManager.Instance.OnGlobalProgressChanged(percent);
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- Package.DisplayName, progress));
+ packageName, progress));
});
var updateResult = await basePackage.Update(Package, progress);
@@ -170,15 +196,15 @@ public partial class PackageCardViewModel : ProgressViewModel
InstalledVersion = Package.DisplayVersion ?? "Unknown";
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- Package.DisplayName,
+ packageName,
new ProgressReport(1f, "Update complete", type: ProgressType.Update)));
}
catch (Exception e)
{
- logger.Error(e, "Error Updating Package ({PackageName})", basePackage.Name);
+ logger.LogError(e, "Error Updating Package ({PackageName})", basePackage.Name);
notificationService.ShowPersistent($"Error Updating {Package.DisplayName}", e.Message, NotificationType.Error);
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId,
- Package.DisplayName,
+ packageName,
new ProgressReport(0f, "Update failed", type: ProgressType.Update), Failed: true));
}
finally
@@ -188,6 +214,29 @@ public partial class PackageCardViewModel : ProgressViewModel
Text = "";
}
}
+
+ public async Task Import()
+ {
+ if (!IsUnknownPackage || Design.IsDesignMode) return;
+
+ PackageImportViewModel viewModel = null!;
+ var dialog = vmFactory.GetDialog(vm =>
+ {
+ vm.PackagePath = new DirectoryPath(Package?.FullPath ?? throw new InvalidOperationException());
+ viewModel = vm;
+ });
+
+ dialog.PrimaryButtonText = Resources.Action_Import;
+ dialog.CloseButtonText = Resources.Action_Cancel;
+ dialog.DefaultButton = ContentDialogButton.Primary;
+
+ var result = await dialog.ShowAsync();
+ if (result == ContentDialogResult.Primary)
+ {
+ viewModel.AddPackageWithCurrentInputs();
+ EventManager.Instance.OnInstalledPackagesChanged();
+ }
+ }
public async Task OpenFolder()
{
@@ -199,7 +248,7 @@ public partial class PackageCardViewModel : ProgressViewModel
private async Task HasUpdate()
{
- if (Package == null)
+ if (Package == null || IsUnknownPackage || Design.IsDesignMode)
return false;
var basePackage = packageFactory[Package.PackageName!];
@@ -224,86 +273,8 @@ public partial class PackageCardViewModel : ProgressViewModel
}
catch (Exception e)
{
- logger.Error(e, $"Error checking {Package.PackageName} for updates");
+ logger.LogError(e, "Error checking {PackageName} for updates", Package.PackageName);
return false;
}
}
-
- ///
- /// Deletes a directory and all of its contents recursively.
- /// Uses Polly to retry the deletion if it fails, up to 5 times with an exponential backoff.
- ///
- ///
- private Task DeleteDirectoryAsync(string targetDirectory)
- {
- var policy = Policy.Handle()
- .WaitAndRetryAsync(3, attempt => TimeSpan.FromMilliseconds(50 * Math.Pow(2, attempt)),
- onRetry: (exception, calculatedWaitDuration) =>
- {
- logger.Warn(
- exception,
- "Deletion of {TargetDirectory} failed. Retrying in {CalculatedWaitDuration}",
- targetDirectory, calculatedWaitDuration);
- });
-
- return policy.ExecuteAsync(async () =>
- {
- await Task.Run(() =>
- {
- DeleteDirectory(targetDirectory);
- });
- });
- }
-
- private void DeleteDirectory(string targetDirectory)
- {
- // Skip if directory does not exist
- if (!Directory.Exists(targetDirectory))
- {
- return;
- }
- // For junction points, delete with recursive false
- if (new DirectoryInfo(targetDirectory).LinkTarget != null)
- {
- logger.Info("Removing junction point {TargetDirectory}", targetDirectory);
- try
- {
- Directory.Delete(targetDirectory, false);
- return;
- }
- catch (IOException ex)
- {
- throw new IOException($"Failed to delete junction point {targetDirectory}", ex);
- }
- }
- // Recursively delete all subdirectories
- var subdirectoryEntries = Directory.GetDirectories(targetDirectory);
- foreach (var subdirectoryPath in subdirectoryEntries)
- {
- DeleteDirectory(subdirectoryPath);
- }
- // Delete all files in the directory
- var fileEntries = Directory.GetFiles(targetDirectory);
- foreach (var filePath in fileEntries)
- {
- try
- {
- File.SetAttributes(filePath, FileAttributes.Normal);
- File.Delete(filePath);
- }
- catch (IOException ex)
- {
- throw new IOException($"Failed to delete file {filePath}", ex);
- }
- }
- // Delete the target directory itself
- try
- {
- Directory.Delete(targetDirectory, false);
- }
- catch (IOException ex)
- {
- throw new IOException($"Failed to delete directory {targetDirectory}", ex);
- }
- }
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
index 6994b721..ac98e1b6 100644
--- a/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
@@ -1,11 +1,13 @@
using System;
+using System.Collections.Generic;
using System.Collections.Immutable;
-using System.Collections.ObjectModel;
+using System.IO;
using System.Linq;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
-using CommunityToolkit.Mvvm.ComponentModel;
+using DynamicData;
+using DynamicData.Binding;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Services;
@@ -17,6 +19,8 @@ using StabilityMatrix.Avalonia.Views.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Helper.Factory;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
@@ -34,39 +38,73 @@ public partial class PackageManagerViewModel : PageViewModelBase
private readonly IPackageFactory packageFactory;
private readonly ServiceManager dialogFactory;
- public PackageManagerViewModel(ISettingsManager settingsManager, IPackageFactory packageFactory,
- ServiceManager dialogFactory)
+ public override string Title => "Packages";
+ public override IconSource IconSource =>
+ new SymbolIconSource { Symbol = Symbol.Box, IsFilled = true };
+
+ ///
+ /// List of installed packages
+ ///
+ private readonly SourceCache installedPackages = new(p => p.Id);
+
+ ///
+ /// List of indexed packages without a corresponding installed package
+ ///
+ private readonly SourceCache unknownInstalledPackages = new(p => p.Id);
+
+ public IObservableCollection Packages { get; } =
+ new ObservableCollectionExtended();
+
+ public IObservableCollection PackageCards { get; } =
+ new ObservableCollectionExtended();
+
+ public PackageManagerViewModel(
+ ISettingsManager settingsManager,
+ IPackageFactory packageFactory,
+ ServiceManager dialogFactory
+ )
{
this.settingsManager = settingsManager;
this.packageFactory = packageFactory;
this.dialogFactory = dialogFactory;
-
+
EventManager.Instance.InstalledPackagesChanged += OnInstalledPackagesChanged;
+
+ var installed = installedPackages.Connect();
+ var unknown = unknownInstalledPackages.Connect();
+
+ installed
+ .Or(unknown)
+ .DeferUntilLoaded()
+ .Bind(Packages)
+ .Transform(p => dialogFactory.Get(vm =>
+ {
+ vm.Package = p;
+ vm.OnLoadedAsync().SafeFireAndForget();
+ }))
+ .Bind(PackageCards)
+ .Subscribe();
}
- [ObservableProperty] private ObservableCollection packages;
+ public void SetPackages(IEnumerable packages)
+ {
+ installedPackages.Edit(s => s.Load(packages));
+ }
+
+ public void SetUnknownPackages(IEnumerable packages)
+ {
+ unknownInstalledPackages.Edit(s => s.Load(packages));
+ }
- public override bool CanNavigateNext { get; protected set; } = true;
- public override bool CanNavigatePrevious { get; protected set; }
- public override string Title => "Packages";
- public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.Box, IsFilled = true};
-
public override async Task OnLoadedAsync()
{
- if (Design.IsDesignMode) return;
-
- var installedPackages = settingsManager.Settings.InstalledPackages;
- Packages = new ObservableCollection(installedPackages.Select(
- package => dialogFactory.Get(vm =>
- {
- vm.Package = package;
- return vm;
- })));
+ if (Design.IsDesignMode)
+ return;
- foreach (var package in Packages)
- {
- await package.OnLoadedAsync();
- }
+ installedPackages.EditDiff(settingsManager.Settings.InstalledPackages, InstalledPackage.Comparer);
+
+ var currentUnknown = await Task.Run(IndexUnknownPackages);
+ unknownInstalledPackages.Edit(s => s.Load(currentUnknown));
}
public async Task ShowInstallDialog()
@@ -83,16 +121,40 @@ public partial class PackageManagerViewModel : PageViewModelBase
IsPrimaryButtonEnabled = false,
IsSecondaryButtonEnabled = false,
IsFooterVisible = false,
- Content = new InstallerDialog
- {
- DataContext = viewModel
- }
+ Content = new InstallerDialog { DataContext = viewModel }
};
await dialog.ShowAsync();
await OnLoadedAsync();
}
+ private IEnumerable IndexUnknownPackages()
+ {
+ var packageDir = new DirectoryPath(settingsManager.LibraryDir).JoinDir("Packages");
+
+ if (!packageDir.Exists)
+ {
+ yield break;
+ }
+
+ var currentPackages = settingsManager.Settings.InstalledPackages.ToImmutableArray();
+
+ foreach (var subDir in packageDir.Info
+ .EnumerateDirectories()
+ .Select(info => new DirectoryPath(info)))
+ {
+ var expectedLibraryPath = $"Packages{Path.DirectorySeparatorChar}{subDir.Name}";
+
+ // Skip if the package is already installed
+ if (currentPackages.Any(p => p.LibraryPath == expectedLibraryPath))
+ {
+ continue;
+ }
+
+ yield return UnknownInstalledPackage.FromDirectoryName(subDir.Name);
+ }
+ }
+
private void OnInstalledPackagesChanged(object? sender, EventArgs e) =>
OnLoadedAsync().SafeFireAndForget();
}
diff --git a/StabilityMatrix.Avalonia/ViewModels/ProgressItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/ProgressItemViewModel.cs
index ceb7a2de..9d6cc55e 100644
--- a/StabilityMatrix.Avalonia/ViewModels/ProgressItemViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/ProgressItemViewModel.cs
@@ -6,21 +6,16 @@ using StabilityMatrix.Core.Models.Progress;
namespace StabilityMatrix.Avalonia.ViewModels;
-public partial class ProgressItemViewModel : ViewModelBase
+public class ProgressItemViewModel : ProgressItemViewModelBase
{
- [ObservableProperty] private Guid id;
- [ObservableProperty] private string name;
- [ObservableProperty] private ProgressReport progress;
- [ObservableProperty] private bool failed;
- [ObservableProperty] private string? progressText;
-
public ProgressItemViewModel(ProgressItem progressItem)
{
Id = progressItem.ProgressId;
Name = progressItem.Name;
- Progress = progressItem.Progress;
+ Progress.Value = progressItem.Progress.Percentage;
Failed = progressItem.Failed;
- ProgressText = GetProgressText(Progress);
+ Progress.Text = GetProgressText(progressItem.Progress);
+ Progress.IsIndeterminate = progressItem.Progress.IsIndeterminate;
EventManager.Instance.ProgressChanged += OnProgressChanged;
}
@@ -30,9 +25,10 @@ public partial class ProgressItemViewModel : ViewModelBase
if (e.ProgressId != Id)
return;
- Progress = e.Progress;
+ Progress.Value = e.Progress.Percentage;
Failed = e.Failed;
- ProgressText = GetProgressText(Progress);
+ Progress.Text = GetProgressText(e.Progress);
+ Progress.IsIndeterminate = e.Progress.IsIndeterminate;
}
private string GetProgressText(ProgressReport report)
diff --git a/StabilityMatrix.Avalonia/ViewModels/ProgressManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/ProgressManagerViewModel.cs
index 7109a042..99ee8ebd 100644
--- a/StabilityMatrix.Avalonia/ViewModels/ProgressManagerViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/ProgressManagerViewModel.cs
@@ -1,13 +1,21 @@
using System;
+using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
+using Avalonia.Collections;
+using Avalonia.Controls.Notifications;
using CommunityToolkit.Mvvm.ComponentModel;
using FluentAvalonia.UI.Controls;
+using Polly;
+using StabilityMatrix.Avalonia.Models;
+using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Services;
using Symbol = FluentIcons.Common.Symbol;
using SymbolIconSource = FluentIcons.FluentAvalonia.SymbolIconSource;
@@ -16,17 +24,72 @@ namespace StabilityMatrix.Avalonia.ViewModels;
[View(typeof(ProgressManagerPage))]
public partial class ProgressManagerViewModel : PageViewModelBase
{
+ private readonly INotificationService notificationService;
+
public override string Title => "Download Manager";
public override IconSource IconSource => new SymbolIconSource {Symbol = Symbol.ArrowCircleDown, IsFilled = true};
- [ObservableProperty]
- private ObservableCollection progressItems;
+ public AvaloniaList ProgressItems { get; } = new();
+
+ public ProgressManagerViewModel(
+ ITrackedDownloadService trackedDownloadService,
+ INotificationService notificationService)
+ {
+ this.notificationService = notificationService;
+
+ // Attach to the event
+ trackedDownloadService.DownloadAdded += TrackedDownloadService_OnDownloadAdded;
+ }
- public ProgressManagerViewModel()
+ private void TrackedDownloadService_OnDownloadAdded(object? sender, TrackedDownload e)
{
- ProgressItems = new ObservableCollection();
+ var vm = new DownloadProgressItemViewModel(e);
+
+ // Attach notification handlers
+ e.ProgressStateChanged += (s, state) =>
+ {
+ var download = s as TrackedDownload;
+
+ switch (state)
+ {
+ case ProgressState.Success:
+ notificationService.Show("Download Completed", $"Download of {e.FileName} completed successfully.", NotificationType.Success);
+ break;
+ case ProgressState.Failed:
+ var msg = "";
+ if (download?.Exception is { } exception)
+ {
+ msg = $"({exception.GetType().Name}) {exception.Message}";
+ }
+ notificationService.ShowPersistent("Download Failed", $"Download of {e.FileName} failed: {msg}", NotificationType.Error);
+ break;
+ case ProgressState.Cancelled:
+ notificationService.Show("Download Cancelled", $"Download of {e.FileName} was cancelled.", NotificationType.Warning);
+ break;
+ }
+ };
+
+ ProgressItems.Add(vm);
}
+ public void AddDownloads(IEnumerable downloads)
+ {
+ foreach (var download in downloads)
+ {
+ if (ProgressItems.Any(vm => vm.Id == download.Id))
+ {
+ continue;
+ }
+ var vm = new DownloadProgressItemViewModel(download);
+ ProgressItems.Add(vm);
+ }
+ }
+
+ private void ShowFailedNotification(string title, string message)
+ {
+ notificationService.ShowPersistent(title, message, NotificationType.Error);
+ }
+
public void StartEventListener()
{
EventManager.Instance.ProgressChanged += OnProgressChanged;
@@ -34,12 +97,7 @@ public partial class ProgressManagerViewModel : PageViewModelBase
public void ClearDownloads()
{
- if (!ProgressItems.Any(p => Math.Abs(p.Progress.Percentage - 100) < 0.01f || p.Failed))
- return;
-
- var itemsInProgress = ProgressItems
- .Where(p => p.Progress.Percentage < 100 && !p.Failed).ToList();
- ProgressItems = new ObservableCollection(itemsInProgress);
+ ProgressItems.RemoveAll(ProgressItems.Where(x => x.IsCompleted));
}
private void OnProgressChanged(object? sender, ProgressItem e)
diff --git a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
index 54026801..ba33442d 100644
--- a/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
+++ b/StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
@@ -2,6 +2,8 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
+using System.Diagnostics;
+using System.Globalization;
using System.IO;
using System.Linq;
using System.Reflection;
@@ -11,12 +13,14 @@ using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls.Notifications;
using Avalonia.Styling;
+using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Helpers;
+using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
@@ -45,6 +49,7 @@ public partial class SettingsViewModel : PageViewModelBase
private readonly IPrerequisiteHelper prerequisiteHelper;
private readonly IPyRunner pyRunner;
private readonly ServiceManager dialogFactory;
+ private readonly ITrackedDownloadService trackedDownloadService;
public SharedState SharedState { get; }
@@ -65,6 +70,11 @@ public partial class SettingsViewModel : PageViewModelBase
"System",
};
+ [ObservableProperty] private CultureInfo selectedLanguage;
+
+ // ReSharper disable once MemberCanBeMadeStatic.Global
+ public IReadOnlyList AvailableLanguages => Cultures.SupportedCultures;
+
public IReadOnlyList AnimationScaleOptions { get; } = new[]
{
0f,
@@ -103,17 +113,20 @@ public partial class SettingsViewModel : PageViewModelBase
IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner,
ServiceManager dialogFactory,
- SharedState sharedState)
+ SharedState sharedState,
+ ITrackedDownloadService trackedDownloadService)
{
this.notificationService = notificationService;
this.settingsManager = settingsManager;
this.prerequisiteHelper = prerequisiteHelper;
this.pyRunner = pyRunner;
this.dialogFactory = dialogFactory;
+ this.trackedDownloadService = trackedDownloadService;
SharedState = sharedState;
SelectedTheme = settingsManager.Settings.Theme ?? AvailableThemes[1];
+ SelectedLanguage = Cultures.GetSupportedCultureOrDefault(settingsManager.Settings.Language);
RemoveSymlinksOnShutdown = settingsManager.Settings.RemoveFolderLinksOnShutdown;
SelectedAnimationScale = settingsManager.Settings.AnimationScale;
@@ -142,6 +155,43 @@ public partial class SettingsViewModel : PageViewModelBase
_ => ThemeVariant.Default
};
}
+
+ partial void OnSelectedLanguageChanged(CultureInfo? oldValue, CultureInfo newValue)
+ {
+ if (oldValue is null || newValue.Name == Cultures.Current.Name) return;
+ // Set locale
+ if (AvailableLanguages.Contains(newValue))
+ {
+ Logger.Info("Changing language from {Old} to {New}",
+ oldValue, newValue);
+
+ Cultures.TrySetSupportedCulture(newValue);
+ settingsManager.Transaction(s => s.Language = newValue.Name);
+
+ var dialog = new BetterContentDialog
+ {
+ Title = Resources.Label_RelaunchRequired,
+ Content = Resources.Text_RelaunchRequiredToApplyLanguage,
+ DefaultButton = ContentDialogButton.Primary,
+ PrimaryButtonText = Resources.Action_Relaunch,
+ CloseButtonText = Resources.Action_RelaunchLater
+ };
+
+ Dispatcher.UIThread.InvokeAsync(async () =>
+ {
+ if (await dialog.ShowAsync() == ContentDialogResult.Primary)
+ {
+ Process.Start(Compat.AppCurrentPath);
+ App.Shutdown();
+ }
+ });
+ }
+ else
+ {
+ Logger.Info("Requested invalid language change from {Old} to {New}",
+ oldValue, newValue);
+ }
+ }
partial void OnRemoveSymlinksOnShutdownChanged(bool value)
{
@@ -329,6 +379,50 @@ public partial class SettingsViewModel : PageViewModelBase
"Stability Matrix has been added to the Start Menu for all users.", NotificationType.Success);
}
+ public async Task PickNewDataDirectory()
+ {
+ var viewModel = dialogFactory.Get();
+ var dialog = new BetterContentDialog
+ {
+ IsPrimaryButtonEnabled = false,
+ IsSecondaryButtonEnabled = false,
+ IsFooterVisible = false,
+ Content = new SelectDataDirectoryDialog
+ {
+ DataContext = viewModel
+ }
+ };
+
+ var result = await dialog.ShowAsync();
+ if (result == ContentDialogResult.Primary)
+ {
+ // 1. For portable mode, call settings.SetPortableMode()
+ if (viewModel.IsPortableMode)
+ {
+ settingsManager.SetPortableMode();
+ }
+ // 2. For custom path, call settings.SetLibraryPath(path)
+ else
+ {
+ settingsManager.SetLibraryPath(viewModel.DataDirectory);
+ }
+
+ // Restart
+ var restartDialog = new BetterContentDialog
+ {
+ Title = "Restart required",
+ Content = "Stability Matrix must be restarted for the changes to take effect.",
+ PrimaryButtonText = "Restart",
+ DefaultButton = ContentDialogButton.Primary,
+ IsSecondaryButtonEnabled = false,
+ };
+ await restartDialog.ShowAsync();
+
+ Process.Start(Compat.AppCurrentPath);
+ App.Shutdown();
+ }
+ }
+
#endregion
#region Debug Section
@@ -406,6 +500,32 @@ public partial class SettingsViewModel : PageViewModelBase
// Use try-catch to generate traceback information
throw new OperationCanceledException("Example Message");
}
+
+ [RelayCommand]
+ private async Task DebugTrackedDownload()
+ {
+ var textFields = new TextBoxField[]
+ {
+ new()
+ {
+ Label = "Url",
+ },
+ new()
+ {
+ Label = "File path"
+ }
+ };
+
+ var dialog = DialogHelper.CreateTextEntryDialog("Add download", "", textFields);
+
+ if (await dialog.ShowAsync() == ContentDialogResult.Primary)
+ {
+ var url = textFields[0].Text;
+ var filePath = textFields[1].Text;
+ var download = trackedDownloadService.NewDownload(new Uri(url), new FilePath(filePath));
+ download.Start();
+ }
+ }
#endregion
#region Info Section
diff --git a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml
index a8352c5d..2a4a3541 100644
--- a/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml
+++ b/StabilityMatrix.Avalonia/Views/CheckpointBrowserPage.axaml
@@ -8,7 +8,8 @@
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:vm="clr-namespace:StabilityMatrix.Avalonia.ViewModels.CheckpointManager"
xmlns:checkpointBrowser="clr-namespace:StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser"
- mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="600"
+ xmlns:avalonia="clr-namespace:Projektanker.Icons.Avalonia;assembly=Projektanker.Icons.Avalonia"
+ mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="700"
x:DataType="viewModels:CheckpointBrowserViewModel"
d:DataContext="{x:Static designData:DesignData.CheckpointBrowserViewModel}"
x:CompileBindings="True"
@@ -40,7 +41,7 @@
Margin="0,8,0,8"
Height="300"
StretchDirection="Both"
- CornerRadius="4"
+ CornerRadius="8"
VerticalContentAlignment="Top"
HorizontalContentAlignment="Center"
Source="{Binding CardImage}"
@@ -274,13 +275,30 @@
+
-
diff --git a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
index bb98b979..21178818 100644
--- a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
+++ b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml
@@ -239,17 +239,19 @@
-
+
-
-
-
+
+
+
+
+
-
-
-
-
-
+ Text="Drag & drop checkpoints here to import"
+ IsVisible="{Binding !CheckpointFiles.Count}"/>
+ IsVisible="{Binding Progress.IsTextVisible}" />
+ IsVisible="{Binding Progress.IsProgressVisible}"
+ Value="{Binding Progress.Value, FallbackValue=20}" />
-
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml
index 2332b32d..36961405 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/InstallerDialog.axaml
@@ -135,7 +135,8 @@
TextWrapping="Wrap"
IsVisible="{Binding SelectedPackage.Disclaimer, Converter={x:Static StringConverters.IsNotNullOrEmpty}}"/>
-
+
-
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml
new file mode 100644
index 00000000..8cdbf093
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml
@@ -0,0 +1,51 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs
new file mode 100644
index 00000000..a5f8cf50
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/PackageImportDialog.axaml.cs
@@ -0,0 +1,17 @@
+using Avalonia.Markup.Xaml;
+using StabilityMatrix.Avalonia.Controls;
+
+namespace StabilityMatrix.Avalonia.Views.Dialogs;
+
+public partial class PackageImportDialog : UserControlBase
+{
+ public PackageImportDialog()
+ {
+ InitializeComponent();
+ }
+
+ private void InitializeComponent()
+ {
+ AvaloniaXamlLoader.Load(this);
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml
index bdbccdc5..2fe90df8 100644
--- a/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml
+++ b/StabilityMatrix.Avalonia/Views/Dialogs/SelectDataDirectoryDialog.axaml
@@ -50,7 +50,7 @@
+ Margin="0,32,0,0" />
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml.cs b/StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml.cs
new file mode 100644
index 00000000..28d9880a
--- /dev/null
+++ b/StabilityMatrix.Avalonia/Views/NewCheckpointsPage.axaml.cs
@@ -0,0 +1,11 @@
+using StabilityMatrix.Avalonia.Controls;
+
+namespace StabilityMatrix.Avalonia.Views;
+
+public partial class NewCheckpointsPage : UserControlBase
+{
+ public NewCheckpointsPage()
+ {
+ InitializeComponent();
+ }
+}
diff --git a/StabilityMatrix.Avalonia/Views/PackageManagerPage.axaml b/StabilityMatrix.Avalonia/Views/PackageManagerPage.axaml
index 74242ac3..72f22091 100644
--- a/StabilityMatrix.Avalonia/Views/PackageManagerPage.axaml
+++ b/StabilityMatrix.Avalonia/Views/PackageManagerPage.axaml
@@ -4,11 +4,11 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:viewModels="clr-namespace:StabilityMatrix.Avalonia.ViewModels"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
- xmlns:models="clr-namespace:StabilityMatrix.Core.Models;assembly=StabilityMatrix.Core"
xmlns:controls="clr-namespace:StabilityMatrix.Avalonia.Controls"
xmlns:designData="clr-namespace:StabilityMatrix.Avalonia.DesignData"
xmlns:packageManager="clr-namespace:StabilityMatrix.Avalonia.ViewModels.PackageManager"
xmlns:faicon="clr-namespace:Projektanker.Icons.Avalonia;assembly=Projektanker.Icons.Avalonia"
+ xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:DataType="viewModels:PackageManagerViewModel"
x:CompileBindings="True"
@@ -24,7 +24,7 @@
+ ItemsSource="{Binding PackageCards}">
@@ -36,7 +36,7 @@
+
+
diff --git a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml
index 8e3d6d0d..cb745ba9 100644
--- a/StabilityMatrix.Avalonia/Views/SettingsPage.axaml
+++ b/StabilityMatrix.Avalonia/Views/SettingsPage.axaml
@@ -1,317 +1,368 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/StabilityMatrix.Core/Converters/Json/StringJsonConverter.cs b/StabilityMatrix.Core/Converters/Json/StringJsonConverter.cs
new file mode 100644
index 00000000..16904246
--- /dev/null
+++ b/StabilityMatrix.Core/Converters/Json/StringJsonConverter.cs
@@ -0,0 +1,34 @@
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace StabilityMatrix.Core.Converters.Json;
+
+///
+/// Json converter for types that serialize to string by `ToString()` and
+/// can be created by `Activator.CreateInstance(Type, string)`
+///
+public class StringJsonConverter : JsonConverter
+{
+ ///
+ public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ if (reader.TokenType != JsonTokenType.String)
+ {
+ throw new JsonException();
+ }
+
+ var value = reader.GetString();
+ if (value is null)
+ {
+ throw new JsonException();
+ }
+
+ return (T?) Activator.CreateInstance(typeToConvert, value);
+ }
+
+ ///
+ public override void Write(Utf8JsonWriter writer, T? value, JsonSerializerOptions options)
+ {
+ writer.WriteStringValue(value?.ToString());
+ }
+}
diff --git a/StabilityMatrix.Core/Helper/FileHash.cs b/StabilityMatrix.Core/Helper/FileHash.cs
index a29d87c0..e0430256 100644
--- a/StabilityMatrix.Core/Helper/FileHash.cs
+++ b/StabilityMatrix.Core/Helper/FileHash.cs
@@ -46,8 +46,8 @@ public static class FileHash
var hash = await GetHashAsync(SHA256.Create(), stream, buffer, totalBytesRead =>
{
- progress?.Report(new ProgressReport(totalBytesRead, totalBytes));
- });
+ progress?.Report(new ProgressReport(totalBytesRead, totalBytes, type: ProgressType.Hashing));
+ }).ConfigureAwait(false);
return hash;
}
finally
diff --git a/StabilityMatrix.Core/Helper/HardwareHelper.cs b/StabilityMatrix.Core/Helper/HardwareHelper.cs
index 66d678ca..eaef9df8 100644
--- a/StabilityMatrix.Core/Helper/HardwareHelper.cs
+++ b/StabilityMatrix.Core/Helper/HardwareHelper.cs
@@ -123,6 +123,16 @@ public static partial class HardwareHelper
{
return IterGpuInfo().Any(gpu => gpu.IsAmd);
}
+
+ // Set ROCm for default if AMD and Linux
+ public static bool PreferRocm() => !HardwareHelper.HasNvidiaGpu()
+ && HardwareHelper.HasAmdGpu()
+ && Compat.IsLinux;
+
+ // Set DirectML for default if AMD and Windows
+ public static bool PreferDirectML() => !HardwareHelper.HasNvidiaGpu()
+ && HardwareHelper.HasAmdGpu()
+ && Compat.IsWindows;
}
public enum Level
diff --git a/StabilityMatrix.Core/Helper/PropertyComparer.cs b/StabilityMatrix.Core/Helper/PropertyComparer.cs
new file mode 100644
index 00000000..f00f9665
--- /dev/null
+++ b/StabilityMatrix.Core/Helper/PropertyComparer.cs
@@ -0,0 +1,24 @@
+namespace StabilityMatrix.Core.Helper;
+
+public class PropertyComparer : IEqualityComparer where T : class
+{
+ private Func Expr { get; set; }
+
+ public PropertyComparer(Func expr)
+ {
+ Expr = expr;
+ }
+ public bool Equals(T? x, T? y)
+ {
+ if (x == null || y == null) return false;
+
+ var first = Expr.Invoke(x);
+ var second = Expr.Invoke(y);
+
+ return first.Equals(second);
+ }
+ public int GetHashCode(T obj)
+ {
+ return obj.GetHashCode();
+ }
+}
diff --git a/StabilityMatrix.Core/Models/Api/CivitModelFpType.cs b/StabilityMatrix.Core/Models/Api/CivitModelFpType.cs
index f1250c3f..817e2ea2 100644
--- a/StabilityMatrix.Core/Models/Api/CivitModelFpType.cs
+++ b/StabilityMatrix.Core/Models/Api/CivitModelFpType.cs
@@ -8,6 +8,8 @@ namespace StabilityMatrix.Core.Models.Api;
[SuppressMessage("ReSharper", "InconsistentNaming")]
public enum CivitModelFpType
{
+ bf16,
fp16,
- fp32
+ fp32,
+ tf32
}
diff --git a/StabilityMatrix.Core/Models/Api/CivitModelType.cs b/StabilityMatrix.Core/Models/Api/CivitModelType.cs
index 944115be..0331e252 100644
--- a/StabilityMatrix.Core/Models/Api/CivitModelType.cs
+++ b/StabilityMatrix.Core/Models/Api/CivitModelType.cs
@@ -9,23 +9,29 @@ namespace StabilityMatrix.Core.Models.Api;
[SuppressMessage("ReSharper", "InconsistentNaming")]
public enum CivitModelType
{
- Unknown,
[ConvertTo(SharedFolderType.StableDiffusion)]
Checkpoint,
[ConvertTo(SharedFolderType.TextualInversion)]
TextualInversion,
[ConvertTo(SharedFolderType.Hypernetwork)]
Hypernetwork,
- AestheticGradient,
[ConvertTo(SharedFolderType.Lora)]
LORA,
[ConvertTo(SharedFolderType.ControlNet)]
Controlnet,
- Poses,
- [ConvertTo(SharedFolderType.StableDiffusion)]
- Model,
[ConvertTo(SharedFolderType.LyCORIS)]
LoCon,
+ [ConvertTo(SharedFolderType.VAE)]
+ VAE,
+
+ // Unused/obsolete/unknown/meta options
+ AestheticGradient,
+ Model,
+ Poses,
+ Upscaler,
+ Wildcards,
+ Workflows,
Other,
All,
+ Unknown
}
diff --git a/StabilityMatrix.Core/Models/Api/CivitModelsRequest.cs b/StabilityMatrix.Core/Models/Api/CivitModelsRequest.cs
index c5a11dd3..66ea7c48 100644
--- a/StabilityMatrix.Core/Models/Api/CivitModelsRequest.cs
+++ b/StabilityMatrix.Core/Models/Api/CivitModelsRequest.cs
@@ -117,4 +117,22 @@ public class CivitModelsRequest
///
[AliasAs("baseModels")]
public string? BaseModel { get; set; }
+
+ [AliasAs("ids")]
+ public string CommaSeparatedModelIds { get; set; }
+
+ public override string ToString()
+ {
+ return $"Page: {Page}, " +
+ $"Query: {Query}, " +
+ $"Tag: {Tag}, " +
+ $"Username: {Username}, " +
+ $"Types: {Types}, " +
+ $"Sort: {Sort}, " +
+ $"Period: {Period}, " +
+ $"Rating: {Rating}, " +
+ $"Nsfw: {Nsfw}, " +
+ $"BaseModel: {BaseModel}, " +
+ $"CommaSeparatedModelIds: {CommaSeparatedModelIds}";
+ }
}
diff --git a/StabilityMatrix.Core/Models/Api/CivitSortMode.cs b/StabilityMatrix.Core/Models/Api/CivitSortMode.cs
index 6c522865..241f7ff6 100644
--- a/StabilityMatrix.Core/Models/Api/CivitSortMode.cs
+++ b/StabilityMatrix.Core/Models/Api/CivitSortMode.cs
@@ -11,5 +11,7 @@ public enum CivitSortMode
[EnumMember(Value = "Most Downloaded")]
MostDownloaded,
[EnumMember(Value = "Newest")]
- Newest
+ Newest,
+ [EnumMember(Value = "Installed")]
+ Installed,
}
diff --git a/StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs b/StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs
new file mode 100644
index 00000000..82e703f3
--- /dev/null
+++ b/StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs
@@ -0,0 +1,44 @@
+using System.Diagnostics;
+using System.Text.Json;
+using StabilityMatrix.Core.Models.Api;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Models;
+
+public class CivitPostDownloadContextAction : IContextAction
+{
+ ///
+ public object? Context { get; set; }
+
+ public static CivitPostDownloadContextAction FromCivitFile(CivitFile file)
+ {
+ return new CivitPostDownloadContextAction
+ {
+ Context = file.Hashes.BLAKE3
+ };
+ }
+
+ public void Invoke(ISettingsManager settingsManager)
+ {
+ var result = Context as string;
+
+ if (Context is JsonElement jsonElement)
+ {
+ result = jsonElement.GetString();
+ }
+
+ if (result is null)
+ {
+ Debug.WriteLine($"Context {Context} is not a string.");
+ return;
+ }
+
+ Debug.WriteLine($"Adding {result} to installed models.");
+ settingsManager.Transaction(
+ s =>
+ {
+ s.InstalledModelHashes ??= new HashSet();
+ s.InstalledModelHashes.Add(result);
+ });
+ }
+}
diff --git a/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs b/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs
index a80b0ab9..c125e7a1 100644
--- a/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs
+++ b/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs
@@ -1,14 +1,19 @@
using System.Diagnostics.CodeAnalysis;
+using System.Text.Json.Serialization;
+using StabilityMatrix.Core.Converters.Json;
namespace StabilityMatrix.Core.Models.FileInterfaces;
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+[JsonConverter(typeof(StringJsonConverter))]
public class DirectoryPath : FileSystemPath, IPathObject
{
private DirectoryInfo? info;
// ReSharper disable once MemberCanBePrivate.Global
+ [JsonIgnore]
public DirectoryInfo Info => info ??= new DirectoryInfo(FullPath);
+ [JsonIgnore]
public bool IsSymbolicLink
{
get
@@ -21,14 +26,17 @@ public class DirectoryPath : FileSystemPath, IPathObject
///
/// Gets a value indicating whether the directory exists.
///
+ [JsonIgnore]
public bool Exists => Info.Exists;
///
+ [JsonIgnore]
public string Name => Info.Name;
///
/// Get the parent directory.
///
+ [JsonIgnore]
public DirectoryPath? Parent => Info.Parent == null
? null : new DirectoryPath(Info.Parent);
diff --git a/StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs b/StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs
index d8a54fcd..03e48728 100644
--- a/StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs
+++ b/StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs
@@ -1,15 +1,20 @@
using System.Diagnostics.CodeAnalysis;
using System.Text;
+using System.Text.Json.Serialization;
+using StabilityMatrix.Core.Converters.Json;
namespace StabilityMatrix.Core.Models.FileInterfaces;
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
+[JsonConverter(typeof(StringJsonConverter))]
public class FilePath : FileSystemPath, IPathObject
{
private FileInfo? _info;
// ReSharper disable once MemberCanBePrivate.Global
+ [JsonIgnore]
public FileInfo Info => _info ??= new FileInfo(FullPath);
+ [JsonIgnore]
public bool IsSymbolicLink
{
get
@@ -19,13 +24,16 @@ public class FilePath : FileSystemPath, IPathObject
}
}
+ [JsonIgnore]
public bool Exists => Info.Exists;
+ [JsonIgnore]
public string Name => Info.Name;
///
/// Get the directory of the file.
///
+ [JsonIgnore]
public DirectoryPath? Directory
{
get
@@ -115,12 +123,22 @@ public class FilePath : FileSystemPath, IPathObject
return File.WriteAllBytesAsync(FullPath, bytes, ct);
}
+ ///
+ /// Move the file to a directory.
+ ///
+ public FilePath MoveTo(FilePath destinationFile)
+ {
+ Info.MoveTo(destinationFile.FullPath, true);
+ // Return the new path
+ return destinationFile;
+ }
+
///
/// Move the file to a directory.
///
public async Task MoveToAsync(DirectoryPath directory)
{
- await Task.Run(() => Info.MoveTo(directory.FullPath));
+ await Task.Run(() => Info.MoveTo(directory.FullPath)).ConfigureAwait(false);
// Return the new path
return directory.JoinFile(this);
}
@@ -130,7 +148,7 @@ public class FilePath : FileSystemPath, IPathObject
///
public async Task MoveToAsync(FilePath destinationFile)
{
- await Task.Run(() => Info.MoveTo(destinationFile.FullPath));
+ await Task.Run(() => Info.MoveTo(destinationFile.FullPath)).ConfigureAwait(false);
// Return the new path
return destinationFile;
}
diff --git a/StabilityMatrix.Core/Models/IContextAction.cs b/StabilityMatrix.Core/Models/IContextAction.cs
new file mode 100644
index 00000000..805531be
--- /dev/null
+++ b/StabilityMatrix.Core/Models/IContextAction.cs
@@ -0,0 +1,9 @@
+using System.Text.Json.Serialization;
+
+namespace StabilityMatrix.Core.Models;
+
+[JsonDerivedType(typeof(CivitPostDownloadContextAction), "CivitPostDownload")]
+public interface IContextAction
+{
+ object? Context { get; set; }
+}
diff --git a/StabilityMatrix.Core/Models/InstalledPackage.cs b/StabilityMatrix.Core/Models/InstalledPackage.cs
index 627582a1..26d00423 100644
--- a/StabilityMatrix.Core/Models/InstalledPackage.cs
+++ b/StabilityMatrix.Core/Models/InstalledPackage.cs
@@ -39,7 +39,7 @@ public class InstalledPackage
public DateTimeOffset? LastUpdateCheck { get; set; }
public bool UpdateAvailable { get; set; }
-
+
///
/// Get the path as a relative sub-path of the relative path.
/// If not a sub-path, return null.
@@ -158,6 +158,9 @@ public class InstalledPackage
LibraryPath = System.IO.Path.Combine("Packages", packageFolderName);
}
+ public static IEqualityComparer Comparer { get; } =
+ new PropertyComparer(p => p.Id);
+
protected bool Equals(InstalledPackage other)
{
return Id.Equals(other.Id);
diff --git a/StabilityMatrix.Core/Models/LaunchOptionCard.cs b/StabilityMatrix.Core/Models/LaunchOptionCard.cs
index 3e5dddf5..8882029c 100644
--- a/StabilityMatrix.Core/Models/LaunchOptionCard.cs
+++ b/StabilityMatrix.Core/Models/LaunchOptionCard.cs
@@ -45,8 +45,13 @@ public readonly record struct LaunchOptionCard
// During card creation, store dict of options with initial values
var initialOptions = new Dictionary();
- // Dict of
- var launchArgsDict = launchArgs.ToDictionary(launchArg => launchArg.Name);
+ // To dictionary ignoring duplicates
+ var launchArgsDict = launchArgs
+ .ToLookup(launchArg => launchArg.Name)
+ .ToDictionary(
+ group => group.Key,
+ group => group.First()
+ );
// Create cards
foreach (var definition in definitions)
diff --git a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
index 6dddc6ef..025f87e2 100644
--- a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
+++ b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
@@ -139,7 +139,7 @@ public abstract class BaseGitPackage : BasePackage
}
public override async Task DownloadPackage(string version, bool isCommitHash,
- IProgress? progress = null)
+ string? branch, IProgress? progress = null)
{
var downloadUrl = GetDownloadUrl(version, isCommitHash);
@@ -151,7 +151,7 @@ public abstract class BaseGitPackage : BasePackage
await DownloadService
.DownloadToFileAsync(downloadUrl, DownloadLocation, progress: progress)
.ConfigureAwait(false);
-
+
progress?.Report(new ProgressReport(100, message: "Download Complete"));
return version;
@@ -246,7 +246,8 @@ public abstract class BaseGitPackage : BasePackage
{
var releases = await GetAllReleases().ConfigureAwait(false);
var latestRelease = releases.First(x => includePrerelease || !x.Prerelease);
- await DownloadPackage(latestRelease.TagName, false, progress).ConfigureAwait(false);
+ await DownloadPackage(latestRelease.TagName, false, null, progress)
+ .ConfigureAwait(false);
await InstallPackage(progress).ConfigureAwait(false);
return latestRelease.TagName;
}
@@ -260,8 +261,9 @@ public abstract class BaseGitPackage : BasePackage
{
throw new Exception("No commits found for branch");
}
-
- await DownloadPackage(latestCommit.Sha, true, progress).ConfigureAwait(false);
+
+ await DownloadPackage(latestCommit.Sha, true, installedPackage.InstalledBranch, progress)
+ .ConfigureAwait(false);
await InstallPackage(progress).ConfigureAwait(false);
return latestCommit.Sha;
}
diff --git a/StabilityMatrix.Core/Models/Packages/BasePackage.cs b/StabilityMatrix.Core/Models/Packages/BasePackage.cs
index 285e5b10..34b9acff 100644
--- a/StabilityMatrix.Core/Models/Packages/BasePackage.cs
+++ b/StabilityMatrix.Core/Models/Packages/BasePackage.cs
@@ -18,6 +18,7 @@ public abstract class BasePackage
public abstract string LicenseType { get; }
public abstract string LicenseUrl { get; }
public virtual string Disclaimer => string.Empty;
+ public virtual bool OfferInOneClickInstaller => true;
///
/// Primary command to launch the package. 'Launch' buttons uses this.
@@ -33,7 +34,7 @@ public abstract class BasePackage
public virtual bool ShouldIgnoreReleases => false;
public virtual bool UpdateAvailable { get; set; }
- public abstract Task DownloadPackage(string version, bool isCommitHash,
+ public abstract Task DownloadPackage(string version, bool isCommitHash, string? branch,
IProgress? progress = null);
public abstract Task InstallPackage(IProgress? progress = null);
public abstract Task RunPackage(string installedPackagePath, string command, string arguments);
@@ -72,8 +73,8 @@ public abstract class BasePackage
public abstract Task> GetAllBranches();
public abstract Task> GetAllReleases();
- public abstract string DownloadLocation { get; }
- public abstract string InstallLocation { get; set; }
+ public virtual string? DownloadLocation { get; }
+ public virtual string? InstallLocation { get; set; }
public event EventHandler? ConsoleOutput;
public event EventHandler? Exited;
diff --git a/StabilityMatrix.Core/Models/Packages/Fooocus.cs b/StabilityMatrix.Core/Models/Packages/Fooocus.cs
new file mode 100644
index 00000000..b9324f3e
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Packages/Fooocus.cs
@@ -0,0 +1,122 @@
+using System.Diagnostics;
+using System.Text.RegularExpressions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.Cache;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Models.Packages;
+
+public class Fooocus : BaseGitPackage
+{
+ public Fooocus(IGithubApiCache githubApi, ISettingsManager settingsManager,
+ IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper) : base(githubApi,
+ settingsManager, downloadService, prerequisiteHelper)
+ {
+ }
+
+ public override string Name => "Fooocus";
+ public override string DisplayName { get; set; } = "Fooocus";
+ public override string Author => "lllyasviel";
+
+ public override string Blurb =>
+ "Fooocus is a rethinking of Stable Diffusion and Midjourney’s designs";
+
+ public override string LicenseType => "GPL-3.0";
+ public override string LicenseUrl => "https://github.com/lllyasviel/Fooocus/blob/main/LICENSE";
+ public override string LaunchCommand => "launch.py";
+
+ public override Uri PreviewImageUri =>
+ new("https://user-images.githubusercontent.com/19834515/261830306-f79c5981-cf80-4ee3-b06b-3fef3f8bfbc7.png");
+
+ public override List LaunchOptions => new()
+ {
+ LaunchOptionDefinition.Extras
+ };
+
+ public override Dictionary> SharedFolders => new()
+ {
+ [SharedFolderType.StableDiffusion] = new[] {"models/checkpoints"},
+ [SharedFolderType.Diffusers] = new[] {"models/diffusers"},
+ [SharedFolderType.Lora] = new[] {"models/loras"},
+ [SharedFolderType.CLIP] = new[] {"models/clip"},
+ [SharedFolderType.TextualInversion] = new[] {"models/embeddings"},
+ [SharedFolderType.VAE] = new[] {"models/vae"},
+ [SharedFolderType.ApproxVAE] = new[] {"models/vae_approx"},
+ [SharedFolderType.ControlNet] = new[] {"models/controlnet"},
+ [SharedFolderType.GLIGEN] = new[] {"models/gligen"},
+ [SharedFolderType.ESRGAN] = new[] {"models/upscale_models"},
+ [SharedFolderType.Hypernetwork] = new[] {"models/hypernetworks"}
+ };
+
+ public override async Task GetLatestVersion()
+ {
+ var release = await GetLatestRelease().ConfigureAwait(false);
+ return release.TagName!;
+ }
+
+ public override async Task InstallPackage(IProgress? progress = null)
+ {
+ await base.InstallPackage(progress).ConfigureAwait(false);
+ var venvRunner = await SetupVenv(InstallLocation).ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
+
+ var torchVersion = "cpu";
+ var gpus = HardwareHelper.IterGpuInfo().ToList();
+
+ if (gpus.Any(g => g.IsNvidia))
+ {
+ torchVersion = "cu118";
+ }
+ else if (HardwareHelper.PreferRocm())
+ {
+ torchVersion = "rocm5.4.2";
+ }
+
+ await venvRunner
+ .PipInstall(
+ $"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersion}",
+ OnConsoleOutput).ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(-1f, "Installing requirements...",
+ isIndeterminate: true));
+ await venvRunner.PipInstall("-r requirements_versions.txt", OnConsoleOutput)
+ .ConfigureAwait(false);
+ }
+
+ public override async Task RunPackage(string installedPackagePath, string command, string arguments)
+ {
+ await SetupVenv(installedPackagePath).ConfigureAwait(false);
+
+ void HandleConsoleOutput(ProcessOutput s)
+ {
+ OnConsoleOutput(s);
+
+ if (s.Text.Contains("Use the app with", StringComparison.OrdinalIgnoreCase))
+ {
+ var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
+ var match = regex.Match(s.Text);
+ if (match.Success)
+ {
+ WebUrl = match.Value;
+ }
+ OnStartupComplete(WebUrl);
+ }
+ }
+
+ void HandleExit(int i)
+ {
+ Debug.WriteLine($"Venv process exited with code {i}");
+ OnExit(i);
+ }
+
+ var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
+
+ VenvRunner?.RunDetached(
+ args.TrimEnd(),
+ HandleConsoleOutput,
+ HandleExit);
+ }
+}
diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs
index 7ac7592a..cdc26769 100644
--- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs
+++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs
@@ -116,7 +116,7 @@ public class InvokeAI : BaseGitPackage
public override Task GetLatestVersion() => Task.FromResult("main");
- public override Task DownloadPackage(string version, bool isCommitHash,
+ public override Task DownloadPackage(string version, bool isCommitHash, string? branch,
IProgress? progress = null)
{
return Task.FromResult(version);
diff --git a/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs b/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs
new file mode 100644
index 00000000..069d1368
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Packages/UnknownPackage.cs
@@ -0,0 +1,105 @@
+using Octokit;
+using StabilityMatrix.Core.Models.Database;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+
+namespace StabilityMatrix.Core.Models.Packages;
+
+public class UnknownPackage : BasePackage
+{
+ public static string Key => "unknown-package";
+ public override string Name => Key;
+ public override string DisplayName { get; set; } = "Unknown Package";
+ public override string Author => "";
+
+ public override string GithubUrl => "";
+ public override string LicenseType => "AGPL-3.0";
+ public override string LicenseUrl =>
+ "https://github.com/LykosAI/StabilityMatrix/blob/main/LICENSE";
+ public override string Blurb => "A dank interface for diffusion";
+ public override string LaunchCommand => "test";
+
+ public override Uri PreviewImageUri => new("");
+
+ public override IReadOnlyList ExtraLaunchCommands => new[]
+ {
+ "test-config",
+ };
+
+ ///
+ public override Task DownloadPackage(string version, bool isCommitHash, string? branch, IProgress