Browse Source

formatting

pull/117/head
jt 1 year ago
parent
commit
e0462b5bd1
  1. 10
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Controls/LogViewerControl.axaml.cs
  2. 16
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Converters/ChangeColorTypeConverter.cs
  3. 8
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Converters/EventIdConverter.cs
  4. 10
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Extensions/LoggerExtensions.cs
  5. 52
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/DataStoreLoggerConfiguration.cs
  6. 2
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/ILogDataStoreImpl.cs
  7. 2
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/LogDataStore.cs
  8. 7
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/LogEntryColor.cs
  9. 12
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/LogModel.cs
  10. 2
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/ViewModels/LogViewerControlViewModel.cs
  11. 13
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/ViewModels/ObservableObject.cs
  12. 4
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/ViewModels/ViewModel.cs
  13. 36
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/DataStoreLoggerTarget.cs
  14. 28
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Extensions/ServicesExtension.cs
  15. 4
      StabilityMatrix.Avalonia.Diagnostics/LogViewer/Logging/LogDataStore.cs
  16. 7
      StabilityMatrix.Avalonia.Diagnostics/ViewModels/LogWindowViewModel.cs
  17. 11
      StabilityMatrix.Avalonia.Diagnostics/Views/LogWindow.axaml.cs
  18. 3
      StabilityMatrix.Avalonia/App.axaml
  19. 128
      StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs
  20. 15
      StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs
  21. 4
      StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs
  22. 6
      StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs
  23. 2
      StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs
  24. 303
      StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs
  25. 80
      StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs
  26. 178
      StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs
  27. 24
      StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageModificationDialogViewModel.cs
  28. 93
      StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs
  29. 49
      StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs
  30. 64
      StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs
  31. 329
      StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs
  32. 182
      StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs
  33. 48
      StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs
  34. 322
      StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
  35. 18
      StabilityMatrix.Avalonia/Views/Dialogs/PackageModificationDialog.axaml.cs
  36. 16
      StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs
  37. 1
      StabilityMatrix.Avalonia/Views/MainWindow.axaml
  38. 3
      StabilityMatrix.Core/Database/ILiteDbContext.cs
  39. 115
      StabilityMatrix.Core/Database/LiteDbContext.cs
  40. 84
      StabilityMatrix.Core/Helper/SharedFolders.cs
  41. 3
      StabilityMatrix.Core/Models/Database/GitCommit.cs
  42. 7
      StabilityMatrix.Core/Models/Database/LocalModelFile.cs
  43. 75
      StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs
  44. 91
      StabilityMatrix.Core/Models/InstalledPackage.cs
  45. 13
      StabilityMatrix.Core/Models/InstalledPackageVersion.cs
  46. 6
      StabilityMatrix.Core/Models/PackageModification/AddInstalledPackageStep.cs
  47. 9
      StabilityMatrix.Core/Models/PackageModification/DownloadPackageVersionStep.cs
  48. 2
      StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs
  49. 6
      StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs
  50. 5
      StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs
  51. 12
      StabilityMatrix.Core/Models/PackageModification/SetupModelFoldersStep.cs
  52. 11
      StabilityMatrix.Core/Models/PackageModification/SetupPrerequisitesStep.cs
  53. 249
      StabilityMatrix.Core/Models/Packages/A3WebUI.cs
  54. 194
      StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
  55. 148
      StabilityMatrix.Core/Models/Packages/BasePackage.cs
  56. 227
      StabilityMatrix.Core/Models/Packages/ComfyUI.cs
  57. 45
      StabilityMatrix.Core/Models/Packages/DankDiffusion.cs
  58. 153
      StabilityMatrix.Core/Models/Packages/Fooocus.cs
  59. 270
      StabilityMatrix.Core/Models/Packages/InvokeAI.cs
  60. 84
      StabilityMatrix.Core/Models/Packages/UnknownPackage.cs
  61. 387
      StabilityMatrix.Core/Models/Packages/VladAutomatic.cs
  62. 226
      StabilityMatrix.Core/Models/Packages/VoltaML.cs
  63. 2
      StabilityMatrix.Core/Services/IModelIndexService.cs
  64. 11
      StabilityMatrix.Core/Services/ISettingsManager.cs
  65. 71
      StabilityMatrix.Core/Services/ModelIndexService.cs
  66. 259
      StabilityMatrix.Core/Services/SettingsManager.cs
  67. 98
      StabilityMatrix.Core/Services/TrackedDownloadService.cs

10
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Controls/LogViewerControl.axaml.cs

@ -8,8 +8,7 @@ namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Controls;
public partial class LogViewerControl : UserControl public partial class LogViewerControl : UserControl
{ {
public LogViewerControl() public LogViewerControl() => InitializeComponent();
=> InitializeComponent();
private ILogDataStoreImpl? vm; private ILogDataStoreImpl? vm;
private LogModel? item; private LogModel? item;
@ -17,7 +16,7 @@ public partial class LogViewerControl : UserControl
protected override void OnDataContextChanged(EventArgs e) protected override void OnDataContextChanged(EventArgs e)
{ {
base.OnDataContextChanged(e); base.OnDataContextChanged(e);
if (DataContext is null) if (DataContext is null)
return; return;
@ -45,8 +44,9 @@ public partial class LogViewerControl : UserControl
protected override void OnDetachedFromLogicalTree(LogicalTreeAttachmentEventArgs e) protected override void OnDetachedFromLogicalTree(LogicalTreeAttachmentEventArgs e)
{ {
base.OnDetachedFromLogicalTree(e); base.OnDetachedFromLogicalTree(e);
if (vm is null) return; if (vm is null)
return;
vm.DataStore.Entries.CollectionChanged -= OnCollectionChanged; vm.DataStore.Entries.CollectionChanged -= OnCollectionChanged;
} }
} }

16
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Converters/ChangeColorTypeConverter.cs

@ -13,13 +13,15 @@ public class ChangeColorTypeConverter : IValueConverter
return new SolidColorBrush((Color)(parameter ?? Colors.Black)); return new SolidColorBrush((Color)(parameter ?? Colors.Black));
var sysDrawColor = (SysDrawColor)value!; var sysDrawColor = (SysDrawColor)value!;
return new SolidColorBrush(Color.FromArgb( return new SolidColorBrush(
sysDrawColor.A, Color.FromArgb(sysDrawColor.A, sysDrawColor.R, sysDrawColor.G, sysDrawColor.B)
sysDrawColor.R, );
sysDrawColor.G,
sysDrawColor.B));
} }
public object ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture) public object ConvertBack(
=> throw new NotImplementedException(); object? value,
Type targetType,
object? parameter,
CultureInfo culture
) => throw new NotImplementedException();
} }

8
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Converters/EventIdConverter.cs

@ -17,6 +17,10 @@ public class EventIdConverter : IValueConverter
} }
// If not implemented, an error is thrown // If not implemented, an error is thrown
public object ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture) public object ConvertBack(
=> new EventId(0, value?.ToString() ?? string.Empty); object? value,
Type targetType,
object? parameter,
CultureInfo culture
) => new EventId(0, value?.ToString() ?? string.Empty);
} }

10
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Extensions/LoggerExtensions.cs

@ -4,8 +4,14 @@ namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.Extensions;
public static class LoggerExtensions public static class LoggerExtensions
{ {
public static void Emit(this ILogger logger, EventId eventId, public static void Emit(
LogLevel logLevel, string message, Exception? exception = null, params object?[] args) this ILogger logger,
EventId eventId,
LogLevel logLevel,
string message,
Exception? exception = null,
params object?[] args
)
{ {
if (logger is null) if (logger is null)
return; return;

52
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/DataStoreLoggerConfiguration.cs

@ -6,42 +6,28 @@ namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.Logging;
public class DataStoreLoggerConfiguration public class DataStoreLoggerConfiguration
{ {
#region Properties #region Properties
public EventId EventId { get; set; } public EventId EventId { get; set; }
public Dictionary<LogLevel, LogEntryColor> Colors { get; } = new() public Dictionary<LogLevel, LogEntryColor> Colors { get; } =
{ new()
[LogLevel.Trace] = new LogEntryColor
{
Foreground = Color.DarkGray
},
[LogLevel.Debug] = new LogEntryColor
{
Foreground = Color.Gray
},
[LogLevel.Information] = new LogEntryColor
{
Foreground = Color.WhiteSmoke,
},
[LogLevel.Warning] = new LogEntryColor
{
Foreground = Color.Orange
},
[LogLevel.Error] = new LogEntryColor
{
Foreground = Color.White,
Background = Color.OrangeRed
},
[LogLevel.Critical] = new LogEntryColor
{
Foreground = Color.White,
Background = Color.Red
},
[LogLevel.None] = new LogEntryColor
{ {
Foreground = Color.Magenta [LogLevel.Trace] = new LogEntryColor { Foreground = Color.DarkGray },
} [LogLevel.Debug] = new LogEntryColor { Foreground = Color.Gray },
}; [LogLevel.Information] = new LogEntryColor { Foreground = Color.WhiteSmoke, },
[LogLevel.Warning] = new LogEntryColor { Foreground = Color.Orange },
[LogLevel.Error] = new LogEntryColor
{
Foreground = Color.White,
Background = Color.OrangeRed
},
[LogLevel.Critical] = new LogEntryColor
{
Foreground = Color.White,
Background = Color.Red
},
[LogLevel.None] = new LogEntryColor { Foreground = Color.Magenta }
};
#endregion #endregion
} }

2
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/ILogDataStoreImpl.cs

@ -3,4 +3,4 @@
public interface ILogDataStoreImpl public interface ILogDataStoreImpl
{ {
public ILogDataStore DataStore { get; } public ILogDataStore DataStore { get; }
} }

2
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/LogDataStore.cs

@ -6,7 +6,7 @@ namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.Logging;
public class LogDataStore : ILogDataStore public class LogDataStore : ILogDataStore
{ {
public static LogDataStore Instance { get; } = new(); public static LogDataStore Instance { get; } = new();
#region Fields #region Fields
private static readonly SemaphoreSlim _semaphore = new(initialCount: 1); private static readonly SemaphoreSlim _semaphore = new(initialCount: 1);

7
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/LogEntryColor.cs

@ -4,9 +4,7 @@ namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.Logging;
public class LogEntryColor public class LogEntryColor
{ {
public LogEntryColor() public LogEntryColor() { }
{
}
public LogEntryColor(Color foreground, Color background) public LogEntryColor(Color foreground, Color background)
{ {
@ -16,5 +14,4 @@ public class LogEntryColor
public Color Foreground { get; set; } = Color.Black; public Color Foreground { get; set; } = Color.Black;
public Color Background { get; set; } = Color.Transparent; public Color Background { get; set; } = Color.Transparent;
}
}

12
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/Logging/LogModel.cs

@ -13,11 +13,11 @@ public class LogModel
public EventId EventId { get; set; } public EventId EventId { get; set; }
public object? State { get; set; } public object? State { get; set; }
public string? LoggerName { get; set; } public string? LoggerName { get; set; }
public string? CallerClassName { get; set; } public string? CallerClassName { get; set; }
public string? CallerMemberName { get; set; } public string? CallerMemberName { get; set; }
public string? Exception { get; set; } public string? Exception { get; set; }
@ -26,8 +26,6 @@ public class LogModel
#endregion #endregion
public string LoggerDisplayName => public string LoggerDisplayName =>
LoggerName? LoggerName?.Split('.', StringSplitOptions.RemoveEmptyEntries).LastOrDefault() ?? "";
.Split('.', StringSplitOptions.RemoveEmptyEntries)
.LastOrDefault() ?? "";
} }

2
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/ViewModels/LogViewerControlViewModel.cs

@ -18,4 +18,4 @@ public class LogViewerControlViewModel : ViewModel, ILogDataStoreImpl
public ILogDataStore DataStore { get; set; } public ILogDataStore DataStore { get; set; }
#endregion #endregion
} }

13
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/ViewModels/ObservableObject.cs

@ -5,9 +5,14 @@ namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.ViewModels;
public class ObservableObject : INotifyPropertyChanged public class ObservableObject : INotifyPropertyChanged
{ {
protected bool Set<TValue>(ref TValue field, TValue newValue, [CallerMemberName] string? propertyName = null) protected bool Set<TValue>(
ref TValue field,
TValue newValue,
[CallerMemberName] string? propertyName = null
)
{ {
if (EqualityComparer<TValue>.Default.Equals(field, newValue)) return false; if (EqualityComparer<TValue>.Default.Equals(field, newValue))
return false;
field = newValue; field = newValue;
OnPropertyChanged(propertyName); OnPropertyChanged(propertyName);
@ -16,6 +21,6 @@ public class ObservableObject : INotifyPropertyChanged
public event PropertyChangedEventHandler? PropertyChanged; public event PropertyChangedEventHandler? PropertyChanged;
protected virtual void OnPropertyChanged([CallerMemberName] string? propertyName = null) protected virtual void OnPropertyChanged([CallerMemberName] string? propertyName = null) =>
=> PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName)); PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName));
} }

4
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Core/ViewModels/ViewModel.cs

@ -1,3 +1,5 @@
namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.ViewModels; namespace StabilityMatrix.Avalonia.Diagnostics.LogViewer.Core.ViewModels;
public class ViewModel : ObservableObject { /* skip */ } public class ViewModel
: ObservableObject { /* skip */
}

36
StabilityMatrix.Avalonia.Diagnostics/LogViewer/DataStoreLoggerTarget.cs

@ -54,21 +54,27 @@ public class DataStoreLoggerTarget : TargetWithLayout
} }
// add log entry // add log entry
_dataStore?.AddEntry(new LogModel _dataStore?.AddEntry(
{ new LogModel
Timestamp = DateTime.UtcNow, {
LogLevel = logLevel, Timestamp = DateTime.UtcNow,
// do we override the default EventId if it exists? LogLevel = logLevel,
EventId = eventId.Id == 0 && (_config?.EventId.Id ?? 0) != 0 ? _config!.EventId : eventId, // do we override the default EventId if it exists?
State = message, EventId =
LoggerName = logEvent.LoggerName, eventId.Id == 0 && (_config?.EventId.Id ?? 0) != 0 ? _config!.EventId : eventId,
CallerClassName = logEvent.CallerClassName, State = message,
CallerMemberName = logEvent.CallerMemberName, LoggerName = logEvent.LoggerName,
Exception = logEvent.Exception?.Message ?? (logLevel == MsLogLevel.Error ? message : ""), CallerClassName = logEvent.CallerClassName,
Color = _config!.Colors[logLevel], CallerMemberName = logEvent.CallerMemberName,
}); Exception =
logEvent.Exception?.Message ?? (logLevel == MsLogLevel.Error ? message : ""),
Debug.WriteLine($"--- [{logLevel.ToString()[..3]}] {message} - {logEvent.Exception?.Message ?? "no error"}"); Color = _config!.Colors[logLevel],
}
);
Debug.WriteLine(
$"--- [{logLevel.ToString()[..3]}] {message} - {logEvent.Exception?.Message ?? "no error"}"
);
} }
#endregion #endregion

28
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Extensions/ServicesExtension.cs

@ -20,25 +20,31 @@ public static class ServicesExtension
return services; return services;
} }
public static IServiceCollection AddLogViewer( public static IServiceCollection AddLogViewer(
this IServiceCollection services, this IServiceCollection services,
Action<DataStoreLoggerConfiguration> configure) Action<DataStoreLoggerConfiguration> configure
)
{ {
services.AddSingleton<ILogDataStore>(Core.Logging.LogDataStore.Instance); services.AddSingleton<ILogDataStore>(Core.Logging.LogDataStore.Instance);
services.AddSingleton<LogViewerControlViewModel>(); services.AddSingleton<LogViewerControlViewModel>();
services.Configure(configure); services.Configure(configure);
return services; return services;
} }
public static ILoggingBuilder AddNLogTargets(this ILoggingBuilder builder, IConfiguration config) public static ILoggingBuilder AddNLogTargets(
this ILoggingBuilder builder,
IConfiguration config
)
{ {
LogManager LogManager
.Setup() .Setup()
// Register custom Target // Register custom Target
.SetupExtensions(extensionBuilder => .SetupExtensions(
extensionBuilder.RegisterTarget<DataStoreLoggerTarget>("DataStoreLogger")); extensionBuilder =>
extensionBuilder.RegisterTarget<DataStoreLoggerTarget>("DataStoreLogger")
);
/*builder /*builder
.ClearProviders() .ClearProviders()
@ -56,7 +62,11 @@ public static class ServicesExtension
return builder; return builder;
} }
public static ILoggingBuilder AddNLogTargets(this ILoggingBuilder builder, IConfiguration config, Action<DataStoreLoggerConfiguration> configure) public static ILoggingBuilder AddNLogTargets(
this ILoggingBuilder builder,
IConfiguration config,
Action<DataStoreLoggerConfiguration> configure
)
{ {
builder.AddNLogTargets(config); builder.AddNLogTargets(config);
builder.Services.Configure(configure); builder.Services.Configure(configure);

4
StabilityMatrix.Avalonia.Diagnostics/LogViewer/Logging/LogDataStore.cs

@ -6,8 +6,8 @@ public class LogDataStore : Core.Logging.LogDataStore
{ {
#region Methods #region Methods
public override async void AddEntry(Core.Logging.LogModel logModel) public override async void AddEntry(Core.Logging.LogModel logModel) =>
=> await Dispatcher.UIThread.InvokeAsync(() => base.AddEntry(logModel)); await Dispatcher.UIThread.InvokeAsync(() => base.AddEntry(logModel));
#endregion #endregion
} }

7
StabilityMatrix.Avalonia.Diagnostics/ViewModels/LogWindowViewModel.cs

@ -6,15 +6,14 @@ namespace StabilityMatrix.Avalonia.Diagnostics.ViewModels;
public class LogWindowViewModel public class LogWindowViewModel
{ {
public LogViewerControlViewModel LogViewer { get; } public LogViewerControlViewModel LogViewer { get; }
public LogWindowViewModel(LogViewerControlViewModel logViewer) public LogWindowViewModel(LogViewerControlViewModel logViewer)
{ {
LogViewer = logViewer; LogViewer = logViewer;
} }
public static LogWindowViewModel FromServiceProvider(IServiceProvider services) public static LogWindowViewModel FromServiceProvider(IServiceProvider services)
{ {
return new LogWindowViewModel( return new LogWindowViewModel(services.GetRequiredService<LogViewerControlViewModel>());
services.GetRequiredService<LogViewerControlViewModel>());
} }
} }

11
StabilityMatrix.Avalonia.Diagnostics/Views/LogWindow.axaml.cs

@ -16,13 +16,18 @@ public partial class LogWindow : Window
{ {
return Attach(root, serviceProvider, new KeyGesture(Key.F11)); return Attach(root, serviceProvider, new KeyGesture(Key.F11));
} }
public static IDisposable Attach(TopLevel root, IServiceProvider serviceProvider, KeyGesture gesture) public static IDisposable Attach(
TopLevel root,
IServiceProvider serviceProvider,
KeyGesture gesture
)
{ {
return (root ?? throw new ArgumentNullException(nameof(root))).AddDisposableHandler( return (root ?? throw new ArgumentNullException(nameof(root))).AddDisposableHandler(
KeyDownEvent, KeyDownEvent,
PreviewKeyDown, PreviewKeyDown,
RoutingStrategies.Tunnel); RoutingStrategies.Tunnel
);
void PreviewKeyDown(object? sender, KeyEventArgs e) void PreviewKeyDown(object? sender, KeyEventArgs e)
{ {

3
StabilityMatrix.Avalonia/App.axaml

@ -22,7 +22,8 @@
</Application.Resources> </Application.Resources>
<Application.Styles> <Application.Styles>
<styling:FluentAvaloniaTheme PreferUserAccentColor="True" UseSystemFontOnWindows="True" /> <styling:FluentAvaloniaTheme PreferUserAccentColor="True" UseSystemFontOnWindows="True"
TextVerticalAlignmentOverrideBehavior="Disabled"/>
<StyleInclude Source="avares://AvaloniaEdit/Themes/Fluent/AvaloniaEdit.xaml"/> <StyleInclude Source="avares://AvaloniaEdit/Themes/Fluent/AvaloniaEdit.xaml"/>
<StyleInclude Source="avares://AsyncImageLoader.Avalonia/AdvancedImage.axaml" /> <StyleInclude Source="avares://AsyncImageLoader.Avalonia/AdvancedImage.axaml" />
<StyleInclude Source="Styles/ProgressRing.axaml"/> <StyleInclude Source="Styles/ProgressRing.axaml"/>

128
StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs

@ -20,54 +20,64 @@ public class BetterContentDialog : ContentDialog
#region Reflection Shenanigans for setting content dialog result #region Reflection Shenanigans for setting content dialog result
[NotNull] [NotNull]
protected static readonly FieldInfo? ResultField = typeof(ContentDialog).GetField( protected static readonly FieldInfo? ResultField = typeof(ContentDialog).GetField(
"_result",BindingFlags.Instance | BindingFlags.NonPublic); "_result",
BindingFlags.Instance | BindingFlags.NonPublic
);
protected ContentDialogResult Result protected ContentDialogResult Result
{ {
get => (ContentDialogResult) ResultField.GetValue(this)!; get => (ContentDialogResult)ResultField.GetValue(this)!;
set => ResultField.SetValue(this, value); set => ResultField.SetValue(this, value);
} }
[NotNull] [NotNull]
protected static readonly MethodInfo? HideCoreMethod = typeof(ContentDialog).GetMethod( protected static readonly MethodInfo? HideCoreMethod = typeof(ContentDialog).GetMethod(
"HideCore", BindingFlags.Instance | BindingFlags.NonPublic); "HideCore",
BindingFlags.Instance | BindingFlags.NonPublic
);
protected void HideCore() protected void HideCore()
{ {
HideCoreMethod.Invoke(this, null); HideCoreMethod.Invoke(this, null);
} }
// Also get button properties to hide on command execution change // Also get button properties to hide on command execution change
[NotNull] [NotNull]
protected static readonly FieldInfo? PrimaryButtonField = typeof(ContentDialog).GetField( protected static readonly FieldInfo? PrimaryButtonField = typeof(ContentDialog).GetField(
"_primaryButton", BindingFlags.Instance | BindingFlags.NonPublic); "_primaryButton",
BindingFlags.Instance | BindingFlags.NonPublic
);
protected Button? PrimaryButton protected Button? PrimaryButton
{ {
get => (Button?) PrimaryButtonField.GetValue(this)!; get => (Button?)PrimaryButtonField.GetValue(this)!;
set => PrimaryButtonField.SetValue(this, value); set => PrimaryButtonField.SetValue(this, value);
} }
[NotNull] [NotNull]
protected static readonly FieldInfo? SecondaryButtonField = typeof(ContentDialog).GetField( protected static readonly FieldInfo? SecondaryButtonField = typeof(ContentDialog).GetField(
"_secondaryButton", BindingFlags.Instance | BindingFlags.NonPublic); "_secondaryButton",
BindingFlags.Instance | BindingFlags.NonPublic
protected Button? SecondaryButton );
protected Button? SecondaryButton
{ {
get => (Button?) SecondaryButtonField.GetValue(this)!; get => (Button?)SecondaryButtonField.GetValue(this)!;
set => SecondaryButtonField.SetValue(this, value); set => SecondaryButtonField.SetValue(this, value);
} }
[NotNull] [NotNull]
protected static readonly FieldInfo? CloseButtonField = typeof(ContentDialog).GetField( protected static readonly FieldInfo? CloseButtonField = typeof(ContentDialog).GetField(
"_closeButton", BindingFlags.Instance | BindingFlags.NonPublic); "_closeButton",
BindingFlags.Instance | BindingFlags.NonPublic
protected Button? CloseButton );
protected Button? CloseButton
{ {
get => (Button?) CloseButtonField.GetValue(this)!; get => (Button?)CloseButtonField.GetValue(this)!;
set => CloseButtonField.SetValue(this, value); set => CloseButtonField.SetValue(this, value);
} }
static BetterContentDialog() static BetterContentDialog()
{ {
if (ResultField is null) if (ResultField is null)
@ -84,11 +94,13 @@ public class BetterContentDialog : ContentDialog
} }
} }
#endregion #endregion
protected override Type StyleKeyOverride { get; } = typeof(ContentDialog); protected override Type StyleKeyOverride { get; } = typeof(ContentDialog);
public static readonly StyledProperty<bool> IsFooterVisibleProperty = AvaloniaProperty.Register<BetterContentDialog, bool>( public static readonly StyledProperty<bool> IsFooterVisibleProperty = AvaloniaProperty.Register<
"IsFooterVisible", true); BetterContentDialog,
bool
>("IsFooterVisible", true);
public bool IsFooterVisible public bool IsFooterVisible
{ {
@ -96,9 +108,11 @@ public class BetterContentDialog : ContentDialog
set => SetValue(IsFooterVisibleProperty, value); set => SetValue(IsFooterVisibleProperty, value);
} }
public static readonly StyledProperty<ScrollBarVisibility> ContentVerticalScrollBarVisibilityProperty public static readonly StyledProperty<ScrollBarVisibility> ContentVerticalScrollBarVisibilityProperty =
= AvaloniaProperty.Register<BetterContentDialog, ScrollBarVisibility>( AvaloniaProperty.Register<BetterContentDialog, ScrollBarVisibility>(
"ContentScrollBarVisibility", ScrollBarVisibility.Auto); "ContentScrollBarVisibility",
ScrollBarVisibility.Auto
);
public ScrollBarVisibility ContentVerticalScrollBarVisibility public ScrollBarVisibility ContentVerticalScrollBarVisibility
{ {
@ -106,17 +120,17 @@ public class BetterContentDialog : ContentDialog
set => SetValue(ContentVerticalScrollBarVisibilityProperty, value); set => SetValue(ContentVerticalScrollBarVisibilityProperty, value);
} }
public static readonly StyledProperty<double> MinDialogWidthProperty = AvaloniaProperty.Register<BetterContentDialog, double>( public static readonly StyledProperty<double> MinDialogWidthProperty =
"MinDialogWidth"); AvaloniaProperty.Register<BetterContentDialog, double>("MinDialogWidth");
public double MinDialogWidth public double MinDialogWidth
{ {
get => GetValue(MinDialogWidthProperty); get => GetValue(MinDialogWidthProperty);
set => SetValue(MinDialogWidthProperty, value); set => SetValue(MinDialogWidthProperty, value);
} }
public static readonly StyledProperty<double> MaxDialogWidthProperty = AvaloniaProperty.Register<BetterContentDialog, double>( public static readonly StyledProperty<double> MaxDialogWidthProperty =
"MaxDialogWidth"); AvaloniaProperty.Register<BetterContentDialog, double>("MaxDialogWidth");
public double MaxDialogWidth public double MaxDialogWidth
{ {
@ -124,8 +138,8 @@ public class BetterContentDialog : ContentDialog
set => SetValue(MaxDialogWidthProperty, value); set => SetValue(MaxDialogWidthProperty, value);
} }
public static readonly StyledProperty<double> MaxDialogHeightProperty = AvaloniaProperty.Register<BetterContentDialog, double>( public static readonly StyledProperty<double> MaxDialogHeightProperty =
"MaxDialogHeight"); AvaloniaProperty.Register<BetterContentDialog, double>("MaxDialogHeight");
public double MaxDialogHeight public double MaxDialogHeight
{ {
@ -133,15 +147,14 @@ public class BetterContentDialog : ContentDialog
set => SetValue(MaxDialogHeightProperty, value); set => SetValue(MaxDialogHeightProperty, value);
} }
public static readonly StyledProperty<Thickness> ContentMarginProperty = AvaloniaProperty.Register<BetterContentDialog, Thickness>( public static readonly StyledProperty<Thickness> ContentMarginProperty =
"ContentMargin"); AvaloniaProperty.Register<BetterContentDialog, Thickness>("ContentMargin");
public Thickness ContentMargin public Thickness ContentMargin
{ {
get => GetValue(ContentMarginProperty); get => GetValue(ContentMarginProperty);
set => SetValue(ContentMarginProperty, value); set => SetValue(ContentMarginProperty, value);
} }
public BetterContentDialog() public BetterContentDialog()
{ {
@ -156,13 +169,16 @@ public class BetterContentDialog : ContentDialog
viewModel.SecondaryButtonClick += OnDialogButtonClick; viewModel.SecondaryButtonClick += OnDialogButtonClick;
viewModel.CloseButtonClick += OnDialogButtonClick; viewModel.CloseButtonClick += OnDialogButtonClick;
} }
else if ((Content as Control)?.DataContext is ContentDialogProgressViewModelBase progressViewModel) else if (
(Content as Control)?.DataContext
is ContentDialogProgressViewModelBase progressViewModel
)
{ {
progressViewModel.PrimaryButtonClick += OnDialogButtonClick; progressViewModel.PrimaryButtonClick += OnDialogButtonClick;
progressViewModel.SecondaryButtonClick += OnDialogButtonClick; progressViewModel.SecondaryButtonClick += OnDialogButtonClick;
progressViewModel.CloseButtonClick += OnDialogButtonClick; progressViewModel.CloseButtonClick += OnDialogButtonClick;
} }
// If commands provided, bind OnCanExecuteChanged to hide buttons // If commands provided, bind OnCanExecuteChanged to hide buttons
// otherwise link visibility to IsEnabled // otherwise link visibility to IsEnabled
if (PrimaryButton is not null) if (PrimaryButton is not null)
@ -176,10 +192,11 @@ public class BetterContentDialog : ContentDialog
} }
else else
{ {
PrimaryButton.IsVisible = IsPrimaryButtonEnabled && !string.IsNullOrEmpty(PrimaryButtonText); PrimaryButton.IsVisible =
IsPrimaryButtonEnabled && !string.IsNullOrEmpty(PrimaryButtonText);
} }
} }
if (SecondaryButton is not null) if (SecondaryButton is not null)
{ {
if (SecondaryButtonCommand is not null) if (SecondaryButtonCommand is not null)
@ -191,10 +208,11 @@ public class BetterContentDialog : ContentDialog
} }
else else
{ {
SecondaryButton.IsVisible = IsSecondaryButtonEnabled && !string.IsNullOrEmpty(SecondaryButtonText); SecondaryButton.IsVisible =
IsSecondaryButtonEnabled && !string.IsNullOrEmpty(SecondaryButtonText);
} }
} }
if (CloseButton is not null) if (CloseButton is not null)
{ {
if (CloseButtonCommand is not null) if (CloseButtonCommand is not null)
@ -216,7 +234,7 @@ public class BetterContentDialog : ContentDialog
protected override void OnDataContextChanged(EventArgs e) protected override void OnDataContextChanged(EventArgs e)
{ {
base.OnDataContextChanged(e); base.OnDataContextChanged(e);
TryBindButtons(); TryBindButtons();
} }
@ -242,7 +260,7 @@ public class BetterContentDialog : ContentDialog
var border = VisualChildren[0] as Border; var border = VisualChildren[0] as Border;
var panel = border?.Child as Panel; var panel = border?.Child as Panel;
var faBorder = panel?.Children[0] as FABorder; var faBorder = panel?.Children[0] as FABorder;
// Set dialog bounds // Set dialog bounds
if (MaxDialogWidth > 0) if (MaxDialogWidth > 0)
{ {
@ -257,36 +275,38 @@ public class BetterContentDialog : ContentDialog
{ {
faBorder!.MaxHeight = MaxDialogHeight; faBorder!.MaxHeight = MaxDialogHeight;
} }
var border2 = faBorder?.Child as Border; var border2 = faBorder?.Child as Border;
// Named Grid 'DialogSpace' // Named Grid 'DialogSpace'
if (border2?.Child is not Grid dialogSpaceGrid) throw new InvalidOperationException("Could not find DialogSpace grid"); if (border2?.Child is not Grid dialogSpaceGrid)
throw new InvalidOperationException("Could not find DialogSpace grid");
var scrollViewer = dialogSpaceGrid.Children[0] as ScrollViewer; var scrollViewer = dialogSpaceGrid.Children[0] as ScrollViewer;
var actualBorder = dialogSpaceGrid.Children[1] as Border; var actualBorder = dialogSpaceGrid.Children[1] as Border;
// Get the parent border, which is what we want to hide // Get the parent border, which is what we want to hide
if (scrollViewer is null || actualBorder is null) if (scrollViewer is null || actualBorder is null)
{ {
throw new InvalidOperationException("Could not find parent border"); throw new InvalidOperationException("Could not find parent border");
} }
var subBorder = scrollViewer.Content as Border; var subBorder = scrollViewer.Content as Border;
var subGrid = subBorder?.Child as Grid; var subGrid = subBorder?.Child as Grid;
if (subGrid is null) throw new InvalidOperationException("Could not find sub grid"); if (subGrid is null)
throw new InvalidOperationException("Could not find sub grid");
var contentControlTitle = subGrid.Children[0] as ContentControl; var contentControlTitle = subGrid.Children[0] as ContentControl;
// Hide title if empty // Hide title if empty
if (Title is null or string {Length: 0}) if (Title is null or string { Length: 0 })
{ {
contentControlTitle!.IsVisible = false; contentControlTitle!.IsVisible = false;
} }
// Set footer and scrollbar visibility states // Set footer and scrollbar visibility states
actualBorder.IsVisible = IsFooterVisible; actualBorder.IsVisible = IsFooterVisible;
scrollViewer.VerticalScrollBarVisibility = ContentVerticalScrollBarVisibility; scrollViewer.VerticalScrollBarVisibility = ContentVerticalScrollBarVisibility;
// Also call the vm's OnLoad // Also call the vm's OnLoad
if (Content is Control {DataContext: ViewModelBase viewModel}) if (Content is Control { DataContext: ViewModelBase viewModel })
{ {
viewModel.OnLoaded(); viewModel.OnLoaded();
Dispatcher.UIThread.InvokeAsync(viewModel.OnLoadedAsync).SafeFireAndForget(); Dispatcher.UIThread.InvokeAsync(viewModel.OnLoadedAsync).SafeFireAndForget();

15
StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs

@ -12,11 +12,16 @@ public class MockLiteDbContext : ILiteDbContext
{ {
public LiteDatabaseAsync Database => throw new NotImplementedException(); public LiteDatabaseAsync Database => throw new NotImplementedException();
public ILiteCollectionAsync<CivitModel> CivitModels => throw new NotImplementedException(); public ILiteCollectionAsync<CivitModel> CivitModels => throw new NotImplementedException();
public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions => throw new NotImplementedException(); public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions =>
public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache => throw new NotImplementedException(); throw new NotImplementedException();
public ILiteCollectionAsync<LocalModelFile> LocalModelFiles => throw new NotImplementedException(); public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache =>
throw new NotImplementedException();
public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(string hashBlake3) public ILiteCollectionAsync<LocalModelFile> LocalModelFiles =>
throw new NotImplementedException();
public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(
string hashBlake3
)
{ {
return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null)); return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null));
} }

4
StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs

@ -21,7 +21,5 @@ public class MockModelIndexService : IModelIndexService
} }
/// <inheritdoc /> /// <inheritdoc />
public void BackgroundRefreshIndex() public void BackgroundRefreshIndex() { }
{
}
} }

6
StabilityMatrix.Avalonia/ViewModels/Base/ContentDialogProgressViewModelBase.cs

@ -8,17 +8,17 @@ public class ContentDialogProgressViewModelBase : ProgressViewModel
public event EventHandler<ContentDialogResult>? PrimaryButtonClick; public event EventHandler<ContentDialogResult>? PrimaryButtonClick;
public event EventHandler<ContentDialogResult>? SecondaryButtonClick; public event EventHandler<ContentDialogResult>? SecondaryButtonClick;
public event EventHandler<ContentDialogResult>? CloseButtonClick; public event EventHandler<ContentDialogResult>? CloseButtonClick;
public virtual void OnPrimaryButtonClick() public virtual void OnPrimaryButtonClick()
{ {
PrimaryButtonClick?.Invoke(this, ContentDialogResult.Primary); PrimaryButtonClick?.Invoke(this, ContentDialogResult.Primary);
} }
public virtual void OnSecondaryButtonClick() public virtual void OnSecondaryButtonClick()
{ {
SecondaryButtonClick?.Invoke(this, ContentDialogResult.Secondary); SecondaryButtonClick?.Invoke(this, ContentDialogResult.Secondary);
} }
public virtual void OnCloseButtonClick() public virtual void OnCloseButtonClick()
{ {
CloseButtonClick?.Invoke(this, ContentDialogResult.None); CloseButtonClick?.Invoke(this, ContentDialogResult.None);

2
StabilityMatrix.Avalonia/ViewModels/Dialogs/EnvVarsViewModel.cs

@ -20,7 +20,7 @@ public partial class EnvVarsViewModel : ContentDialogViewModelBase
private ObservableCollection<EnvVarKeyPair> envVars = new(); private ObservableCollection<EnvVarKeyPair> envVars = new();
public DataGridCollectionView EnvVarsView => new(EnvVars); public DataGridCollectionView EnvVarsView => new(EnvVars);
[RelayCommand] [RelayCommand]
private void AddRow() private void AddRow()
{ {

303
StabilityMatrix.Avalonia/ViewModels/Dialogs/InstallerViewModel.cs

@ -31,55 +31,83 @@ using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
public partial class InstallerViewModel : ContentDialogViewModelBase public partial class InstallerViewModel : ContentDialogViewModelBase
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly IPyRunner pyRunner; private readonly IPyRunner pyRunner;
private readonly IDownloadService downloadService; private readonly IDownloadService downloadService;
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private readonly IPrerequisiteHelper prerequisiteHelper; private readonly IPrerequisiteHelper prerequisiteHelper;
[ObservableProperty] private BasePackage selectedPackage; [ObservableProperty]
[ObservableProperty] private PackageVersion? selectedVersion; private BasePackage selectedPackage;
[ObservableProperty] private IReadOnlyList<BasePackage>? availablePackages;
[ObservableProperty] private ObservableCollection<GitCommit>? availableCommits; [ObservableProperty]
[ObservableProperty] private ObservableCollection<PackageVersion>? availableVersions; private PackageVersion? selectedVersion;
[ObservableProperty] private GitCommit? selectedCommit;
[ObservableProperty] private string? releaseNotes; [ObservableProperty]
[ObservableProperty] private string latestVersionText = string.Empty; private IReadOnlyList<BasePackage>? availablePackages;
[ObservableProperty] private bool isAdvancedMode;
[ObservableProperty] private bool showDuplicateWarning; [ObservableProperty]
[ObservableProperty] private string? installName; private ObservableCollection<GitCommit>? availableCommits;
[ObservableProperty] private SharedFolderMethod selectedSharedFolderMethod;
[ObservableProperty]
private ObservableCollection<PackageVersion>? availableVersions;
[ObservableProperty]
private GitCommit? selectedCommit;
[ObservableProperty]
private string? releaseNotes;
[ObservableProperty]
private string latestVersionText = string.Empty;
[ObservableProperty]
private bool isAdvancedMode;
[ObservableProperty]
private bool showDuplicateWarning;
[ObservableProperty]
private string? installName;
[ObservableProperty]
private SharedFolderMethod selectedSharedFolderMethod;
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(ShowTorchVersionOptions))] [NotifyPropertyChangedFor(nameof(ShowTorchVersionOptions))]
private TorchVersion selectedTorchVersion; private TorchVersion selectedTorchVersion;
// Version types (release or commit) // Version types (release or commit)
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(ReleaseLabelText), [NotifyPropertyChangedFor(
nameof(IsReleaseMode), nameof(SelectedVersion))] nameof(ReleaseLabelText),
nameof(IsReleaseMode),
nameof(SelectedVersion)
)]
private PackageVersionType selectedVersionType = PackageVersionType.Commit; private PackageVersionType selectedVersionType = PackageVersionType.Commit;
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsReleaseModeAvailable))] [NotifyPropertyChangedFor(nameof(IsReleaseModeAvailable))]
private PackageVersionType availableVersionTypes = private PackageVersionType availableVersionTypes =
PackageVersionType.GithubRelease | PackageVersionType.Commit; PackageVersionType.GithubRelease | PackageVersionType.Commit;
public string ReleaseLabelText => IsReleaseMode ? "Version" : "Branch"; public string ReleaseLabelText => IsReleaseMode ? "Version" : "Branch";
public bool IsReleaseMode public bool IsReleaseMode
{ {
get => SelectedVersionType == PackageVersionType.GithubRelease; get => SelectedVersionType == PackageVersionType.GithubRelease;
set => SelectedVersionType = value ? PackageVersionType.GithubRelease : PackageVersionType.Commit; set =>
SelectedVersionType = value
? PackageVersionType.GithubRelease
: PackageVersionType.Commit;
} }
public bool IsReleaseModeAvailable => AvailableVersionTypes.HasFlag(PackageVersionType.GithubRelease); public bool IsReleaseModeAvailable =>
AvailableVersionTypes.HasFlag(PackageVersionType.GithubRelease);
public bool ShowTorchVersionOptions => SelectedTorchVersion != TorchVersion.None; public bool ShowTorchVersionOptions => SelectedTorchVersion != TorchVersion.None;
public ProgressViewModel InstallProgress { get; } = new(); public ProgressViewModel InstallProgress { get; } = new();
public IEnumerable<IPackageStep> Steps { get; set; } public IEnumerable<IPackageStep> Steps { get; set; }
@ -87,8 +115,10 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
ISettingsManager settingsManager, ISettingsManager settingsManager,
IPackageFactory packageFactory, IPackageFactory packageFactory,
IPyRunner pyRunner, IPyRunner pyRunner,
IDownloadService downloadService, INotificationService notificationService, IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper) INotificationService notificationService,
IPrerequisiteHelper prerequisiteHelper
)
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.pyRunner = pyRunner; this.pyRunner = pyRunner;
@ -97,36 +127,43 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
this.prerequisiteHelper = prerequisiteHelper; this.prerequisiteHelper = prerequisiteHelper;
// AvailablePackages and SelectedPackage // AvailablePackages and SelectedPackage
AvailablePackages = new ObservableCollection<BasePackage>(packageFactory.GetAllAvailablePackages()); AvailablePackages = new ObservableCollection<BasePackage>(
packageFactory.GetAllAvailablePackages()
);
SelectedPackage = AvailablePackages[0]; SelectedPackage = AvailablePackages[0];
} }
public override void OnLoaded() public override void OnLoaded()
{ {
if (AvailablePackages == null) return; if (AvailablePackages == null)
return;
SelectedPackage = AvailablePackages[0]; SelectedPackage = AvailablePackages[0];
IsReleaseMode = !SelectedPackage.ShouldIgnoreReleases; IsReleaseMode = !SelectedPackage.ShouldIgnoreReleases;
} }
public override async Task OnLoadedAsync() public override async Task OnLoadedAsync()
{ {
if (Design.IsDesignMode) return; if (Design.IsDesignMode)
return;
// Check for updates // Check for updates
try try
{ {
var versionOptions = await SelectedPackage.GetAllVersionOptions(); var versionOptions = await SelectedPackage.GetAllVersionOptions();
if (IsReleaseMode) if (IsReleaseMode)
{ {
AvailableVersions = AvailableVersions = new ObservableCollection<PackageVersion>(
new ObservableCollection<PackageVersion>(versionOptions.AvailableVersions); versionOptions.AvailableVersions
if (!AvailableVersions.Any()) return; );
if (!AvailableVersions.Any())
return;
SelectedVersion = AvailableVersions.First(x => !x.IsPrerelease); SelectedVersion = AvailableVersions.First(x => !x.IsPrerelease);
} }
else else
{ {
AvailableVersions = AvailableVersions = new ObservableCollection<PackageVersion>(
new ObservableCollection<PackageVersion>(versionOptions.AvailableBranches); versionOptions.AvailableBranches
);
UpdateSelectedVersionToLatestMain(); UpdateSelectedVersionToLatestMain();
} }
@ -137,14 +174,17 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
Logger.Warn("Error getting versions: {Exception}", e.ToString()); Logger.Warn("Error getting versions: {Exception}", e.ToString());
} }
} }
[RelayCommand] [RelayCommand]
private async Task Install() private async Task Install()
{ {
var result = await notificationService.TryAsync(ActuallyInstall(), "Could not install package"); var result = await notificationService.TryAsync(
ActuallyInstall(),
"Could not install package"
);
if (result.IsSuccessful) if (result.IsSuccessful)
{ {
OnPrimaryButtonClick(); OnPrimaryButtonClick();
} }
else else
{ {
@ -160,41 +200,58 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
await dialog.ShowAsync(); await dialog.ShowAsync();
} }
} }
private Task ActuallyInstall() private Task ActuallyInstall()
{ {
if (string.IsNullOrWhiteSpace(InstallName)) if (string.IsNullOrWhiteSpace(InstallName))
{ {
notificationService.Show(new Notification("Package name is empty", notificationService.Show(
"Please enter a name for the package", NotificationType.Error)); new Notification(
"Package name is empty",
"Please enter a name for the package",
NotificationType.Error
)
);
return Task.CompletedTask; return Task.CompletedTask;
} }
var installLocation = Path.Combine(settingsManager.LibraryDir, "Packages", InstallName); var installLocation = Path.Combine(settingsManager.LibraryDir, "Packages", InstallName);
var prereqStep = new SetupPrerequisitesStep(prerequisiteHelper, pyRunner); var prereqStep = new SetupPrerequisitesStep(prerequisiteHelper, pyRunner);
var downloadOptions = new DownloadPackageVersionOptions(); var downloadOptions = new DownloadPackageVersionOptions();
var installedVersion = new InstalledPackageVersion(); var installedVersion = new InstalledPackageVersion();
if (IsReleaseMode) if (IsReleaseMode)
{ {
downloadOptions.VersionTag = SelectedVersion?.TagName ?? downloadOptions.VersionTag =
throw new NullReferenceException("Selected version is null"); SelectedVersion?.TagName
?? throw new NullReferenceException("Selected version is null");
installedVersion.InstalledReleaseVersion = downloadOptions.VersionTag; installedVersion.InstalledReleaseVersion = downloadOptions.VersionTag;
} }
else else
{ {
downloadOptions.CommitHash = SelectedCommit?.Sha ?? downloadOptions.CommitHash =
throw new NullReferenceException("Selected commit is null"); SelectedCommit?.Sha ?? throw new NullReferenceException("Selected commit is null");
installedVersion.InstalledBranch = SelectedVersion?.TagName ?? installedVersion.InstalledBranch =
throw new NullReferenceException("Selected version is null"); SelectedVersion?.TagName
?? throw new NullReferenceException("Selected version is null");
installedVersion.InstalledCommitSha = downloadOptions.CommitHash; installedVersion.InstalledCommitSha = downloadOptions.CommitHash;
} }
var downloadStep = var downloadStep = new DownloadPackageVersionStep(
new DownloadPackageVersionStep(SelectedPackage, installLocation, downloadOptions); SelectedPackage,
var installStep = new InstallPackageStep(SelectedPackage, SelectedTorchVersion, installLocation); installLocation,
var setupModelFoldersStep = new SetupModelFoldersStep(SelectedPackage, downloadOptions
SelectedSharedFolderMethod, installLocation); );
var installStep = new InstallPackageStep(
SelectedPackage,
SelectedTorchVersion,
installLocation
);
var setupModelFoldersStep = new SetupModelFoldersStep(
SelectedPackage,
SelectedSharedFolderMethod,
installLocation
);
var package = new InstalledPackage var package = new InstalledPackage
{ {
@ -210,7 +267,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
}; };
var addInstalledPackageStep = new AddInstalledPackageStep(settingsManager, package); var addInstalledPackageStep = new AddInstalledPackageStep(settingsManager, package);
var steps = new List<IPackageStep> var steps = new List<IPackageStep>
{ {
prereqStep, prereqStep,
@ -228,7 +285,7 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
{ {
OnCloseButtonClick(); OnCloseButtonClick();
} }
private void UpdateSelectedVersionToLatestMain() private void UpdateSelectedVersionToLatestMain()
{ {
if (AvailableVersions is null) if (AvailableVersions is null)
@ -241,34 +298,34 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
var version = AvailableVersions.FirstOrDefault(x => x.TagName == "master"); var version = AvailableVersions.FirstOrDefault(x => x.TagName == "master");
// If not found, try main // If not found, try main
version ??= AvailableVersions.FirstOrDefault(x => x.TagName == "main"); version ??= AvailableVersions.FirstOrDefault(x => x.TagName == "main");
// If still not found, just use the first one // If still not found, just use the first one
version ??= AvailableVersions[0]; version ??= AvailableVersions[0];
SelectedVersion = version; SelectedVersion = version;
} }
} }
[RelayCommand] [RelayCommand]
private async Task ShowPreview() private async Task ShowPreview()
{ {
var url = SelectedPackage.PreviewImageUri.ToString(); var url = SelectedPackage.PreviewImageUri.ToString();
var imageStream = await downloadService.GetImageStreamFromUrl(url); var imageStream = await downloadService.GetImageStreamFromUrl(url);
var bitmap = new Bitmap(imageStream); var bitmap = new Bitmap(imageStream);
var dialog = new ContentDialog var dialog = new ContentDialog
{ {
PrimaryButtonText = "Open in Browser", PrimaryButtonText = "Open in Browser",
CloseButtonText = "Close", CloseButtonText = "Close",
Content = new Image Content = new Image
{ {
Source = bitmap, Source = bitmap,
Stretch = Stretch.Uniform, Stretch = Stretch.Uniform,
MaxHeight = 500, MaxHeight = 500,
HorizontalAlignment = HorizontalAlignment.Center HorizontalAlignment = HorizontalAlignment.Center
} }
}; };
var result = await dialog.ShowAsync(); var result = await dialog.ShowAsync();
if (result == ContentDialogResult.Primary) if (result == ContentDialogResult.Primary)
{ {
@ -284,10 +341,11 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
SelectedVersionType = value; SelectedVersionType = value;
} }
} }
// When changing branch / release modes, refresh // When changing branch / release modes, refresh
// ReSharper disable once UnusedParameterInPartialMethod // ReSharper disable once UnusedParameterInPartialMethod
partial void OnSelectedVersionTypeChanged(PackageVersionType value) => OnSelectedPackageChanged(SelectedPackage); partial void OnSelectedVersionTypeChanged(PackageVersionType value) =>
OnSelectedPackageChanged(SelectedPackage);
partial void OnSelectedPackageChanged(BasePackage value) partial void OnSelectedPackageChanged(BasePackage value)
{ {
@ -302,74 +360,81 @@ public partial class InstallerViewModel : ContentDialogViewModelBase
SelectedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod; SelectedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod;
SelectedTorchVersion = SelectedPackage.GetRecommendedTorchVersion(); SelectedTorchVersion = SelectedPackage.GetRecommendedTorchVersion();
if (Design.IsDesignMode) return; if (Design.IsDesignMode)
return;
Dispatcher.UIThread.InvokeAsync(async () =>
{ Dispatcher.UIThread
Logger.Debug($"Release mode: {IsReleaseMode}"); .InvokeAsync(async () =>
var versionOptions = await value.GetAllVersionOptions();
AvailableVersions = IsReleaseMode
? new ObservableCollection<PackageVersion>(versionOptions.AvailableVersions)
: new ObservableCollection<PackageVersion>(versionOptions.AvailableBranches);
SelectedVersion = AvailableVersions.First(x => !x.IsPrerelease);
ReleaseNotes = SelectedVersion.ReleaseNotesMarkdown;
Logger.Debug($"Loaded release notes for {ReleaseNotes}");
if (!IsReleaseMode)
{ {
var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList(); Logger.Debug($"Release mode: {IsReleaseMode}");
if (commits is null || commits.Count == 0) return; var versionOptions = await value.GetAllVersionOptions();
AvailableCommits = new ObservableCollection<GitCommit>(commits);
SelectedCommit = AvailableCommits[0];
UpdateSelectedVersionToLatestMain();
}
InstallName = SelectedPackage.DisplayName; AvailableVersions = IsReleaseMode
LatestVersionText = IsReleaseMode ? new ObservableCollection<PackageVersion>(versionOptions.AvailableVersions)
? $"Latest version: {SelectedVersion.TagName}" : new ObservableCollection<PackageVersion>(versionOptions.AvailableBranches);
: $"Branch: {SelectedVersion.TagName}";
}).SafeFireAndForget(); SelectedVersion = AvailableVersions.First(x => !x.IsPrerelease);
ReleaseNotes = SelectedVersion.ReleaseNotesMarkdown;
Logger.Debug($"Loaded release notes for {ReleaseNotes}");
if (!IsReleaseMode)
{
var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList();
if (commits is null || commits.Count == 0)
return;
AvailableCommits = new ObservableCollection<GitCommit>(commits);
SelectedCommit = AvailableCommits[0];
UpdateSelectedVersionToLatestMain();
}
InstallName = SelectedPackage.DisplayName;
LatestVersionText = IsReleaseMode
? $"Latest version: {SelectedVersion.TagName}"
: $"Branch: {SelectedVersion.TagName}";
})
.SafeFireAndForget();
} }
partial void OnInstallNameChanged(string? value) partial void OnInstallNameChanged(string? value)
{ {
ShowDuplicateWarning = ShowDuplicateWarning = settingsManager.Settings.InstalledPackages.Any(
settingsManager.Settings.InstalledPackages.Any(p => p => p.LibraryPath == $"Packages{Path.DirectorySeparatorChar}{value}"
p.LibraryPath == $"Packages{Path.DirectorySeparatorChar}{value}"); );
} }
partial void OnSelectedVersionChanged(PackageVersion? value) partial void OnSelectedVersionChanged(PackageVersion? value)
{ {
ReleaseNotes = value?.ReleaseNotesMarkdown ?? string.Empty; ReleaseNotes = value?.ReleaseNotesMarkdown ?? string.Empty;
if (value == null) return; if (value == null)
return;
SelectedCommit = null; SelectedCommit = null;
AvailableCommits?.Clear(); AvailableCommits?.Clear();
if (!IsReleaseMode) if (!IsReleaseMode)
{ {
Task.Run(async () => Task.Run(async () =>
{
try
{ {
var hashes = await SelectedPackage.GetAllCommits(value.TagName); try
if (hashes is null) throw new Exception("No commits found");
Dispatcher.UIThread.Post(() =>
{ {
AvailableCommits = new ObservableCollection<GitCommit>(hashes); var hashes = await SelectedPackage.GetAllCommits(value.TagName);
SelectedCommit = AvailableCommits[0]; if (hashes is null)
}); throw new Exception("No commits found");
}
catch (Exception e) Dispatcher.UIThread.Post(() =>
{ {
Logger.Warn($"Error getting commits: {e.Message}"); AvailableCommits = new ObservableCollection<GitCommit>(hashes);
} SelectedCommit = AvailableCommits[0];
}).SafeFireAndForget(); });
}
catch (Exception e)
{
Logger.Warn($"Error getting commits: {e.Message}");
}
})
.SafeFireAndForget();
} }
} }
} }

80
StabilityMatrix.Avalonia/ViewModels/Dialogs/OneClickInstallViewModel.cs

@ -26,14 +26,27 @@ public partial class OneClickInstallViewModel : ViewModelBase
private readonly IPyRunner pyRunner; private readonly IPyRunner pyRunner;
private readonly ISharedFolders sharedFolders; private readonly ISharedFolders sharedFolders;
private const string DefaultPackageName = "stable-diffusion-webui"; private const string DefaultPackageName = "stable-diffusion-webui";
[ObservableProperty] private string headerText; [ObservableProperty]
[ObservableProperty] private string subHeaderText; private string headerText;
[ObservableProperty] private string subSubHeaderText = string.Empty;
[ObservableProperty] private bool showInstallButton; [ObservableProperty]
[ObservableProperty] private bool isIndeterminate; private string subHeaderText;
[ObservableProperty] private ObservableCollection<BasePackage> allPackages;
[ObservableProperty] private BasePackage selectedPackage; [ObservableProperty]
private string subSubHeaderText = string.Empty;
[ObservableProperty]
private bool showInstallButton;
[ObservableProperty]
private bool isIndeterminate;
[ObservableProperty]
private ObservableCollection<BasePackage> allPackages;
[ObservableProperty]
private BasePackage selectedPackage;
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsProgressBarVisible))] [NotifyPropertyChangedFor(nameof(IsProgressBarVisible))]
@ -41,9 +54,14 @@ public partial class OneClickInstallViewModel : ViewModelBase
public bool IsProgressBarVisible => OneClickInstallProgress > 0 || IsIndeterminate; public bool IsProgressBarVisible => OneClickInstallProgress > 0 || IsIndeterminate;
public OneClickInstallViewModel(ISettingsManager settingsManager, IPackageFactory packageFactory, public OneClickInstallViewModel(
IPrerequisiteHelper prerequisiteHelper, ILogger<OneClickInstallViewModel> logger, IPyRunner pyRunner, ISettingsManager settingsManager,
ISharedFolders sharedFolders) IPackageFactory packageFactory,
IPrerequisiteHelper prerequisiteHelper,
ILogger<OneClickInstallViewModel> logger,
IPyRunner pyRunner,
ISharedFolders sharedFolders
)
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.packageFactory = packageFactory; this.packageFactory = packageFactory;
@ -55,9 +73,9 @@ public partial class OneClickInstallViewModel : ViewModelBase
HeaderText = "Welcome to Stability Matrix!"; HeaderText = "Welcome to Stability Matrix!";
SubHeaderText = "Choose your preferred interface and click Install to get started!"; SubHeaderText = "Choose your preferred interface and click Install to get started!";
ShowInstallButton = true; ShowInstallButton = true;
AllPackages = AllPackages = new ObservableCollection<BasePackage>(
new ObservableCollection<BasePackage>(this.packageFactory.GetAllAvailablePackages() this.packageFactory.GetAllAvailablePackages().Where(p => p.OfferInOneClickInstaller)
.Where(p => p.OfferInOneClickInstaller)); );
SelectedPackage = AllPackages[0]; SelectedPackage = AllPackages[0];
} }
@ -75,7 +93,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
EventManager.Instance.OnOneClickInstallFinished(true); EventManager.Instance.OnOneClickInstallFinished(true);
return Task.CompletedTask; return Task.CompletedTask;
} }
private async Task DoInstall() private async Task DoInstall()
{ {
HeaderText = $"Installing {SelectedPackage.DisplayName}"; HeaderText = $"Installing {SelectedPackage.DisplayName}";
@ -83,11 +101,11 @@ public partial class OneClickInstallViewModel : ViewModelBase
var progressHandler = new Progress<ProgressReport>(progress => var progressHandler = new Progress<ProgressReport>(progress =>
{ {
SubHeaderText = $"{progress.Title} {progress.Percentage:N0}%"; SubHeaderText = $"{progress.Title} {progress.Percentage:N0}%";
IsIndeterminate = progress.IsIndeterminate; IsIndeterminate = progress.IsIndeterminate;
OneClickInstallProgress = Convert.ToInt32(progress.Percentage); OneClickInstallProgress = Convert.ToInt32(progress.Percentage);
}); });
await prerequisiteHelper.InstallAllIfNecessary(progressHandler); await prerequisiteHelper.InstallAllIfNecessary(progressHandler);
SubHeaderText = "Installing prerequisites..."; SubHeaderText = "Installing prerequisites...";
@ -104,7 +122,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
IsIndeterminate = false; IsIndeterminate = false;
var libraryDir = settingsManager.LibraryDir; var libraryDir = settingsManager.LibraryDir;
// get latest version & download & install // get latest version & download & install
SubHeaderText = "Getting latest version..."; SubHeaderText = "Getting latest version...";
var installLocation = Path.Combine(libraryDir, "Packages", SelectedPackage.Name); var installLocation = Path.Combine(libraryDir, "Packages", SelectedPackage.Name);
@ -112,12 +130,11 @@ public partial class OneClickInstallViewModel : ViewModelBase
var downloadVersion = new DownloadPackageVersionOptions(); var downloadVersion = new DownloadPackageVersionOptions();
var installedVersion = new InstalledPackageVersion(); var installedVersion = new InstalledPackageVersion();
var versionOptions = await SelectedPackage.GetAllVersionOptions(); var versionOptions = await SelectedPackage.GetAllVersionOptions();
if (versionOptions.AvailableVersions != null && versionOptions.AvailableVersions.Any()) if (versionOptions.AvailableVersions != null && versionOptions.AvailableVersions.Any())
{ {
downloadVersion.VersionTag = downloadVersion.VersionTag = versionOptions.AvailableVersions.First().TagName;
versionOptions.AvailableVersions.First().TagName;
installedVersion.InstalledReleaseVersion = downloadVersion.VersionTag; installedVersion.InstalledReleaseVersion = downloadVersion.VersionTag;
} }
else else
@ -125,16 +142,16 @@ public partial class OneClickInstallViewModel : ViewModelBase
downloadVersion.BranchName = await SelectedPackage.GetLatestVersion(); downloadVersion.BranchName = await SelectedPackage.GetLatestVersion();
installedVersion.InstalledBranch = downloadVersion.BranchName; installedVersion.InstalledBranch = downloadVersion.BranchName;
} }
var torchVersion = SelectedPackage.GetRecommendedTorchVersion(); var torchVersion = SelectedPackage.GetRecommendedTorchVersion();
await DownloadPackage(installLocation, downloadVersion); await DownloadPackage(installLocation, downloadVersion);
await InstallPackage(installLocation, torchVersion); await InstallPackage(installLocation, torchVersion);
SubHeaderText = "Setting up shared folder links..."; SubHeaderText = "Setting up shared folder links...";
var recommendedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod; var recommendedSharedFolderMethod = SelectedPackage.RecommendedSharedFolderMethod;
await SelectedPackage.SetupModelFolders(installLocation, recommendedSharedFolderMethod); await SelectedPackage.SetupModelFolders(installLocation, recommendedSharedFolderMethod);
var installedPackage = new InstalledPackage var installedPackage = new InstalledPackage
{ {
DisplayName = SelectedPackage.DisplayName, DisplayName = SelectedPackage.DisplayName,
@ -151,7 +168,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
st.Settings.InstalledPackages.Add(installedPackage); st.Settings.InstalledPackages.Add(installedPackage);
st.Settings.ActiveInstalledPackageId = installedPackage.Id; st.Settings.ActiveInstalledPackageId = installedPackage.Id;
EventManager.Instance.OnInstalledPackagesChanged(); EventManager.Instance.OnInstalledPackagesChanged();
HeaderText = "Installation complete!"; HeaderText = "Installation complete!";
SubSubHeaderText = string.Empty; SubSubHeaderText = string.Empty;
OneClickInstallProgress = 100; OneClickInstallProgress = 100;
@ -161,15 +178,18 @@ public partial class OneClickInstallViewModel : ViewModelBase
await Task.Delay(1000); await Task.Delay(1000);
SubHeaderText = "Proceeding to Launch page in 1 second..."; SubHeaderText = "Proceeding to Launch page in 1 second...";
await Task.Delay(1000); await Task.Delay(1000);
// should close dialog // should close dialog
EventManager.Instance.OnOneClickInstallFinished(false); EventManager.Instance.OnOneClickInstallFinished(false);
} }
private async Task DownloadPackage(string installLocation, DownloadPackageVersionOptions versionOptions) private async Task DownloadPackage(
string installLocation,
DownloadPackageVersionOptions versionOptions
)
{ {
SubHeaderText = "Downloading package..."; SubHeaderText = "Downloading package...";
var progress = new Progress<ProgressReport>(progress => var progress = new Progress<ProgressReport>(progress =>
{ {
IsIndeterminate = progress.IsIndeterminate; IsIndeterminate = progress.IsIndeterminate;
@ -186,7 +206,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
{ {
SelectedPackage.ConsoleOutput += (_, output) => SubSubHeaderText = output.Text; SelectedPackage.ConsoleOutput += (_, output) => SubSubHeaderText = output.Text;
SubHeaderText = "Downloading and installing package requirements..."; SubHeaderText = "Downloading and installing package requirements...";
var progress = new Progress<ProgressReport>(progress => var progress = new Progress<ProgressReport>(progress =>
{ {
SubHeaderText = "Downloading and installing package requirements..."; SubHeaderText = "Downloading and installing package requirements...";
@ -194,7 +214,7 @@ public partial class OneClickInstallViewModel : ViewModelBase
OneClickInstallProgress = Convert.ToInt32(progress.Percentage); OneClickInstallProgress = Convert.ToInt32(progress.Percentage);
EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress); EventManager.Instance.OnGlobalProgressChanged(OneClickInstallProgress);
}); });
await SelectedPackage.InstallPackage(installLocation, torchVersion, progress); await SelectedPackage.InstallPackage(installLocation, torchVersion, progress);
} }
} }

178
StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageImportViewModel.cs

@ -26,45 +26,58 @@ namespace StabilityMatrix.Avalonia.ViewModels.Dialogs;
public partial class PackageImportViewModel : ContentDialogViewModelBase public partial class PackageImportViewModel : ContentDialogViewModelBase
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly IPackageFactory packageFactory; private readonly IPackageFactory packageFactory;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
[ObservableProperty] private DirectoryPath? packagePath; [ObservableProperty]
[ObservableProperty] private BasePackage? selectedBasePackage; private DirectoryPath? packagePath;
public IReadOnlyList<BasePackage> AvailablePackages [ObservableProperty]
=> packageFactory.GetAllAvailablePackages().ToImmutableArray(); private BasePackage? selectedBasePackage;
[ObservableProperty] private PackageVersion? selectedVersion; public IReadOnlyList<BasePackage> AvailablePackages =>
packageFactory.GetAllAvailablePackages().ToImmutableArray();
[ObservableProperty] private ObservableCollection<GitCommit>? availableCommits;
[ObservableProperty] private ObservableCollection<PackageVersion>? availableVersions; [ObservableProperty]
private PackageVersion? selectedVersion;
[ObservableProperty] private GitCommit? selectedCommit;
[ObservableProperty]
private ObservableCollection<GitCommit>? availableCommits;
[ObservableProperty]
private ObservableCollection<PackageVersion>? availableVersions;
[ObservableProperty]
private GitCommit? selectedCommit;
// Version types (release or commit) // Version types (release or commit)
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(ReleaseLabelText), [NotifyPropertyChangedFor(
nameof(IsReleaseMode), nameof(SelectedVersion))] nameof(ReleaseLabelText),
nameof(IsReleaseMode),
nameof(SelectedVersion)
)]
private PackageVersionType selectedVersionType = PackageVersionType.Commit; private PackageVersionType selectedVersionType = PackageVersionType.Commit;
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(IsReleaseModeAvailable))] [NotifyPropertyChangedFor(nameof(IsReleaseModeAvailable))]
private PackageVersionType availableVersionTypes = private PackageVersionType availableVersionTypes =
PackageVersionType.GithubRelease | PackageVersionType.Commit; PackageVersionType.GithubRelease | PackageVersionType.Commit;
public string ReleaseLabelText => IsReleaseMode ? "Version" : "Branch"; public string ReleaseLabelText => IsReleaseMode ? "Version" : "Branch";
public bool IsReleaseMode public bool IsReleaseMode
{ {
get => SelectedVersionType == PackageVersionType.GithubRelease; get => SelectedVersionType == PackageVersionType.GithubRelease;
set => SelectedVersionType = value ? PackageVersionType.GithubRelease : PackageVersionType.Commit; set =>
SelectedVersionType = value
? PackageVersionType.GithubRelease
: PackageVersionType.Commit;
} }
public bool IsReleaseModeAvailable => AvailableVersionTypes.HasFlag(PackageVersionType.GithubRelease); public bool IsReleaseModeAvailable =>
AvailableVersionTypes.HasFlag(PackageVersionType.GithubRelease);
public PackageImportViewModel(
IPackageFactory packageFactory, public PackageImportViewModel(IPackageFactory packageFactory, ISettingsManager settingsManager)
ISettingsManager settingsManager)
{ {
this.packageFactory = packageFactory; this.packageFactory = packageFactory;
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
@ -73,24 +86,28 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
public override async Task OnLoadedAsync() public override async Task OnLoadedAsync()
{ {
SelectedBasePackage ??= AvailablePackages[0]; SelectedBasePackage ??= AvailablePackages[0];
if (Design.IsDesignMode) return; if (Design.IsDesignMode)
return;
// Populate available versions // Populate available versions
try try
{ {
var versionOptions = await SelectedBasePackage.GetAllVersionOptions(); var versionOptions = await SelectedBasePackage.GetAllVersionOptions();
if (IsReleaseMode) if (IsReleaseMode)
{ {
AvailableVersions = AvailableVersions = new ObservableCollection<PackageVersion>(
new ObservableCollection<PackageVersion>(versionOptions.AvailableVersions); versionOptions.AvailableVersions
if (!AvailableVersions.Any()) return; );
if (!AvailableVersions.Any())
return;
SelectedVersion = AvailableVersions[0]; SelectedVersion = AvailableVersions[0];
} }
else else
{ {
AvailableVersions = AvailableVersions = new ObservableCollection<PackageVersion>(
new ObservableCollection<PackageVersion>(versionOptions.AvailableBranches); versionOptions.AvailableBranches
);
UpdateSelectedVersionToLatestMain(); UpdateSelectedVersionToLatestMain();
} }
} }
@ -99,12 +116,12 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
Logger.Warn("Error getting versions: {Exception}", e.ToString()); Logger.Warn("Error getting versions: {Exception}", e.ToString());
} }
} }
private static string GetDisplayVersion(string version, string? branch) private static string GetDisplayVersion(string version, string? branch)
{ {
return branch == null ? version : $"{branch}@{version[..7]}"; return branch == null ? version : $"{branch}@{version[..7]}";
} }
// When available version types change, reset selected version type if not compatible // When available version types change, reset selected version type if not compatible
partial void OnAvailableVersionTypesChanged(PackageVersionType value) partial void OnAvailableVersionTypesChanged(PackageVersionType value)
{ {
@ -113,11 +130,11 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
SelectedVersionType = value; SelectedVersionType = value;
} }
} }
// When changing branch / release modes, refresh // When changing branch / release modes, refresh
// ReSharper disable once UnusedParameterInPartialMethod // ReSharper disable once UnusedParameterInPartialMethod
partial void OnSelectedVersionTypeChanged(PackageVersionType value) partial void OnSelectedVersionTypeChanged(PackageVersionType value) =>
=> OnSelectedBasePackageChanged(SelectedBasePackage); OnSelectedBasePackageChanged(SelectedBasePackage);
partial void OnSelectedBasePackageChanged(BasePackage? value) partial void OnSelectedBasePackageChanged(BasePackage? value)
{ {
@ -127,38 +144,42 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
AvailableCommits?.Clear(); AvailableCommits?.Clear();
return; return;
} }
AvailableVersions?.Clear(); AvailableVersions?.Clear();
AvailableCommits?.Clear(); AvailableCommits?.Clear();
AvailableVersionTypes = SelectedBasePackage.AvailableVersionTypes; AvailableVersionTypes = SelectedBasePackage.AvailableVersionTypes;
if (Design.IsDesignMode) return; if (Design.IsDesignMode)
return;
Dispatcher.UIThread.InvokeAsync(async () =>
{ Dispatcher.UIThread
Logger.Debug($"Release mode: {IsReleaseMode}"); .InvokeAsync(async () =>
var versionOptions = await value.GetAllVersionOptions();
AvailableVersions = IsReleaseModeAvailable
? new ObservableCollection<PackageVersion>(versionOptions.AvailableVersions)
: new ObservableCollection<PackageVersion>(versionOptions.AvailableBranches);
Logger.Debug($"Available versions: {string.Join(", ", AvailableVersions)}");
SelectedVersion = AvailableVersions[0];
if (!IsReleaseMode)
{ {
var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList(); Logger.Debug($"Release mode: {IsReleaseMode}");
if (commits is null || commits.Count == 0) return; var versionOptions = await value.GetAllVersionOptions();
AvailableCommits = new ObservableCollection<GitCommit>(commits); AvailableVersions = IsReleaseModeAvailable
SelectedCommit = AvailableCommits[0]; ? new ObservableCollection<PackageVersion>(versionOptions.AvailableVersions)
UpdateSelectedVersionToLatestMain(); : new ObservableCollection<PackageVersion>(versionOptions.AvailableBranches);
}
}).SafeFireAndForget(); Logger.Debug($"Available versions: {string.Join(", ", AvailableVersions)}");
SelectedVersion = AvailableVersions[0];
if (!IsReleaseMode)
{
var commits = (await value.GetAllCommits(SelectedVersion.TagName))?.ToList();
if (commits is null || commits.Count == 0)
return;
AvailableCommits = new ObservableCollection<GitCommit>(commits);
SelectedCommit = AvailableCommits[0];
UpdateSelectedVersionToLatestMain();
}
})
.SafeFireAndForget();
} }
private void UpdateSelectedVersionToLatestMain() private void UpdateSelectedVersionToLatestMain()
{ {
if (AvailableVersions is null) if (AvailableVersions is null)
@ -171,14 +192,14 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
var version = AvailableVersions.FirstOrDefault(x => x.TagName == "master"); var version = AvailableVersions.FirstOrDefault(x => x.TagName == "master");
// If not found, try main // If not found, try main
version ??= AvailableVersions.FirstOrDefault(x => x.TagName == "main"); version ??= AvailableVersions.FirstOrDefault(x => x.TagName == "main");
// If still not found, just use the first one // If still not found, just use the first one
version ??= AvailableVersions[0]; version ??= AvailableVersions[0];
SelectedVersion = version; SelectedVersion = version;
} }
} }
public async Task AddPackageWithCurrentInputs() public async Task AddPackageWithCurrentInputs()
{ {
if (SelectedBasePackage is null || PackagePath is null) if (SelectedBasePackage is null || PackagePath is null)
@ -187,20 +208,19 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
var version = new InstalledPackageVersion(); var version = new InstalledPackageVersion();
if (IsReleaseMode) if (IsReleaseMode)
{ {
version.InstalledReleaseVersion = SelectedVersion?.TagName ?? version.InstalledReleaseVersion =
throw new NullReferenceException( SelectedVersion?.TagName
"Selected version is null"); ?? throw new NullReferenceException("Selected version is null");
} }
else else
{ {
version.InstalledBranch = SelectedVersion?.TagName ?? version.InstalledBranch =
throw new NullReferenceException( SelectedVersion?.TagName
"Selected version is null"); ?? throw new NullReferenceException("Selected version is null");
version.InstalledCommitSha = SelectedCommit?.Sha ?? version.InstalledCommitSha =
throw new NullReferenceException( SelectedCommit?.Sha ?? throw new NullReferenceException("Selected commit is null");
"Selected commit is null");
} }
var torchVersion = SelectedBasePackage.GetRecommendedTorchVersion(); var torchVersion = SelectedBasePackage.GetRecommendedTorchVersion();
var sharedFolderRecommendation = SelectedBasePackage.RecommendedSharedFolderMethod; var sharedFolderRecommendation = SelectedBasePackage.RecommendedSharedFolderMethod;
var package = new InstalledPackage var package = new InstalledPackage
@ -215,17 +235,17 @@ public partial class PackageImportViewModel : ContentDialogViewModelBase
PreferredTorchVersion = torchVersion, PreferredTorchVersion = torchVersion,
PreferredSharedFolderMethod = sharedFolderRecommendation PreferredSharedFolderMethod = sharedFolderRecommendation
}; };
// Recreate venv if it's a BaseGitPackage // Recreate venv if it's a BaseGitPackage
if (SelectedBasePackage is BaseGitPackage gitPackage) if (SelectedBasePackage is BaseGitPackage gitPackage)
{ {
await gitPackage.SetupVenv(PackagePath, forceRecreate: true); await gitPackage.SetupVenv(PackagePath, forceRecreate: true);
} }
// Reconfigure shared links // Reconfigure shared links
var recommendedSharedFolderMethod = SelectedBasePackage.RecommendedSharedFolderMethod; var recommendedSharedFolderMethod = SelectedBasePackage.RecommendedSharedFolderMethod;
await SelectedBasePackage.UpdateModelFolders(PackagePath, recommendedSharedFolderMethod); await SelectedBasePackage.UpdateModelFolders(PackagePath, recommendedSharedFolderMethod);
settingsManager.Transaction(s => s.InstalledPackages.Add(package)); settingsManager.Transaction(s => s.InstalledPackages.Add(package));
} }
} }

24
StabilityMatrix.Avalonia/ViewModels/Dialogs/PackageModificationDialogViewModel.cs

@ -16,14 +16,17 @@ public class PackageModificationDialogViewModel : ContentDialogProgressViewModel
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private readonly IEnumerable<IPackageStep> steps; private readonly IEnumerable<IPackageStep> steps;
public PackageModificationDialogViewModel(IPackageModificationRunner packageModificationRunner, public PackageModificationDialogViewModel(
INotificationService notificationService, IEnumerable<IPackageStep> steps) IPackageModificationRunner packageModificationRunner,
INotificationService notificationService,
IEnumerable<IPackageStep> steps
)
{ {
this.packageModificationRunner = packageModificationRunner; this.packageModificationRunner = packageModificationRunner;
this.notificationService = notificationService; this.notificationService = notificationService;
this.steps = steps; this.steps = steps;
} }
public ConsoleViewModel Console { get; } = new(); public ConsoleViewModel Console { get; } = new();
public override async Task OnLoadedAsync() public override async Task OnLoadedAsync()
@ -34,9 +37,12 @@ public class PackageModificationDialogViewModel : ContentDialogProgressViewModel
packageModificationRunner.ProgressChanged += PackageModificationRunnerOnProgressChanged; packageModificationRunner.ProgressChanged += PackageModificationRunnerOnProgressChanged;
await packageModificationRunner.ExecuteSteps(steps.ToList()); await packageModificationRunner.ExecuteSteps(steps.ToList());
notificationService.Show("Package Install Completed", notificationService.Show(
"Package install completed successfully.", NotificationType.Success); "Package Install Completed",
"Package install completed successfully.",
NotificationType.Success
);
OnCloseButtonClick(); OnCloseButtonClick();
} }
} }
@ -46,14 +52,14 @@ public class PackageModificationDialogViewModel : ContentDialogProgressViewModel
Text = string.IsNullOrWhiteSpace(e.Title) Text = string.IsNullOrWhiteSpace(e.Title)
? packageModificationRunner.CurrentStep?.ProgressTitle ? packageModificationRunner.CurrentStep?.ProgressTitle
: e.Title; : e.Title;
Value = e.Percentage; Value = e.Percentage;
Description = e.Message; Description = e.Message;
IsIndeterminate = e.IsIndeterminate; IsIndeterminate = e.IsIndeterminate;
if (string.IsNullOrWhiteSpace(e.Message) || e.Message.Equals("Downloading...")) if (string.IsNullOrWhiteSpace(e.Message) || e.Message.Equals("Downloading..."))
return; return;
Console.PostLine(e.Message); Console.PostLine(e.Message);
EventManager.Instance.OnScrollToBottomRequested(); EventManager.Instance.OnScrollToBottomRequested();
} }

93
StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectDataDirectoryViewModel.cs

@ -24,33 +24,43 @@ public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
public static string DefaultInstallLocation => Compat.AppDataHome; public static string DefaultInstallLocation => Compat.AppDataHome;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private const string ValidExistingDirectoryText = "Valid existing data directory found"; private const string ValidExistingDirectoryText = "Valid existing data directory found";
private const string InvalidDirectoryText = private const string InvalidDirectoryText =
"Directory must be empty or have a valid settings.json file"; "Directory must be empty or have a valid settings.json file";
private const string NotEnoughFreeSpaceText = "Not enough free space on the selected drive"; private const string NotEnoughFreeSpaceText = "Not enough free space on the selected drive";
private const string FatWarningText = private const string FatWarningText = "FAT32 / exFAT drives are not supported at this time";
"FAT32 / exFAT drives are not supported at this time";
[ObservableProperty]
[ObservableProperty] private string dataDirectory = DefaultInstallLocation; private string dataDirectory = DefaultInstallLocation;
[ObservableProperty] private bool isPortableMode;
[ObservableProperty]
[ObservableProperty] private string directoryStatusText = string.Empty; private bool isPortableMode;
[ObservableProperty] private bool isStatusBadgeVisible;
[ObservableProperty] private bool isDirectoryValid; [ObservableProperty]
[ObservableProperty] private bool showFatWarning; private string directoryStatusText = string.Empty;
public RefreshBadgeViewModel ValidatorRefreshBadge { get; } = new() [ObservableProperty]
{ private bool isStatusBadgeVisible;
State = ProgressState.Inactive,
SuccessToolTipText = ValidExistingDirectoryText, [ObservableProperty]
FailToolTipText = InvalidDirectoryText private bool isDirectoryValid;
};
[ObservableProperty]
private bool showFatWarning;
public RefreshBadgeViewModel ValidatorRefreshBadge { get; } =
new()
{
State = ProgressState.Inactive,
SuccessToolTipText = ValidExistingDirectoryText,
FailToolTipText = InvalidDirectoryText
};
public bool HasOldData => settingsManager.GetOldInstalledPackages().Any(); public bool HasOldData => settingsManager.GetOldInstalledPackages().Any();
public SelectDataDirectoryViewModel(ISettingsManager settingsManager) public SelectDataDirectoryViewModel(ISettingsManager settingsManager)
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
@ -61,17 +71,17 @@ public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
{ {
ValidatorRefreshBadge.RefreshCommand.ExecuteAsync(null).SafeFireAndForget(); ValidatorRefreshBadge.RefreshCommand.ExecuteAsync(null).SafeFireAndForget();
} }
// Revalidate on data directory change // Revalidate on data directory change
partial void OnDataDirectoryChanged(string value) partial void OnDataDirectoryChanged(string value)
{ {
ValidatorRefreshBadge.RefreshCommand.ExecuteAsync(null).SafeFireAndForget(); ValidatorRefreshBadge.RefreshCommand.ExecuteAsync(null).SafeFireAndForget();
} }
private async Task<bool> ValidateDataDirectory() private async Task<bool> ValidateDataDirectory()
{ {
await using var delay = new MinimumDelay(100, 200); await using var delay = new MinimumDelay(100, 200);
ShowFatWarning = IsDriveFat(DataDirectory); ShowFatWarning = IsDriveFat(DataDirectory);
// Doesn't exist, this is fine as a new install, hide badge // Doesn't exist, this is fine as a new install, hide badge
@ -83,17 +93,17 @@ public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
} }
// Otherwise check that a settings.json exists // Otherwise check that a settings.json exists
var settingsPath = Path.Combine(DataDirectory, "settings.json"); var settingsPath = Path.Combine(DataDirectory, "settings.json");
// settings.json exists: Try deserializing it // settings.json exists: Try deserializing it
if (File.Exists(settingsPath)) if (File.Exists(settingsPath))
{ {
try try
{ {
var jsonText = await File.ReadAllTextAsync(settingsPath); var jsonText = await File.ReadAllTextAsync(settingsPath);
JsonSerializer.Deserialize<Settings>(jsonText, new JsonSerializerOptions JsonSerializer.Deserialize<Settings>(
{ jsonText,
Converters = { new JsonStringEnumConverter() } new JsonSerializerOptions { Converters = { new JsonStringEnumConverter() } }
}); );
// If successful, show existing badge // If successful, show existing badge
IsStatusBadgeVisible = true; IsStatusBadgeVisible = true;
IsDirectoryValid = true; IsDirectoryValid = true;
@ -110,9 +120,9 @@ public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
return false; return false;
} }
} }
// No settings.json // No settings.json
// Check if the directory is %APPDATA%\StabilityMatrix: hide badge and set directory valid // Check if the directory is %APPDATA%\StabilityMatrix: hide badge and set directory valid
if (DataDirectory == DefaultInstallLocation) if (DataDirectory == DefaultInstallLocation)
{ {
@ -120,7 +130,7 @@ public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
IsDirectoryValid = true; IsDirectoryValid = true;
return true; return true;
} }
// Check if the directory is empty: hide badge and set directory to valid // Check if the directory is empty: hide badge and set directory to valid
var isEmpty = !Directory.EnumerateFileSystemEntries(DataDirectory).Any(); var isEmpty = !Directory.EnumerateFileSystemEntries(DataDirectory).Any();
if (isEmpty) if (isEmpty)
@ -136,21 +146,20 @@ public partial class SelectDataDirectoryViewModel : ContentDialogViewModelBase
DirectoryStatusText = InvalidDirectoryText; DirectoryStatusText = InvalidDirectoryText;
return false; return false;
} }
private bool CanPickFolder => App.StorageProvider.CanPickFolder; private bool CanPickFolder => App.StorageProvider.CanPickFolder;
[RelayCommand(CanExecute = nameof(CanPickFolder))] [RelayCommand(CanExecute = nameof(CanPickFolder))]
private async Task ShowFolderBrowserDialog() private async Task ShowFolderBrowserDialog()
{ {
var provider = App.StorageProvider; var provider = App.StorageProvider;
var result = await provider.OpenFolderPickerAsync(new FolderPickerOpenOptions var result = await provider.OpenFolderPickerAsync(
{ new FolderPickerOpenOptions { Title = "Select Data Folder", AllowMultiple = false }
Title = "Select Data Folder", );
AllowMultiple = false
}); if (result.Count != 1)
return;
if (result.Count != 1) return;
DataDirectory = result[0].Path.LocalPath; DataDirectory = result[0].Path.LocalPath;
} }

49
StabilityMatrix.Avalonia/ViewModels/Dialogs/SelectModelVersionViewModel.cs

@ -24,22 +24,37 @@ public partial class SelectModelVersionViewModel : ContentDialogViewModelBase
public required string Description { get; set; } public required string Description { get; set; }
public required string Title { get; set; } public required string Title { get; set; }
[ObservableProperty] private Bitmap? previewImage; [ObservableProperty]
[ObservableProperty] private ModelVersionViewModel? selectedVersionViewModel; private Bitmap? previewImage;
[ObservableProperty] private CivitFileViewModel? selectedFile;
[ObservableProperty] private bool isImportEnabled; [ObservableProperty]
[ObservableProperty] private ObservableCollection<ImageSource> imageUrls = new(); private ModelVersionViewModel? selectedVersionViewModel;
[ObservableProperty] private bool canGoToNextImage;
[ObservableProperty] private bool canGoToPreviousImage; [ObservableProperty]
private CivitFileViewModel? selectedFile;
[ObservableProperty]
private bool isImportEnabled;
[ObservableProperty]
private ObservableCollection<ImageSource> imageUrls = new();
[ObservableProperty]
private bool canGoToNextImage;
[ObservableProperty]
private bool canGoToPreviousImage;
[ObservableProperty] [ObservableProperty]
[NotifyPropertyChangedFor(nameof(DisplayedPageNumber))] [NotifyPropertyChangedFor(nameof(DisplayedPageNumber))]
private int selectedImageIndex; private int selectedImageIndex;
public int DisplayedPageNumber => SelectedImageIndex + 1; public int DisplayedPageNumber => SelectedImageIndex + 1;
public SelectModelVersionViewModel(ISettingsManager settingsManager, public SelectModelVersionViewModel(
IDownloadService downloadService) ISettingsManager settingsManager,
IDownloadService downloadService
)
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.downloadService = downloadService; this.downloadService = downloadService;
@ -54,12 +69,14 @@ public partial class SelectModelVersionViewModel : ContentDialogViewModelBase
partial void OnSelectedVersionViewModelChanged(ModelVersionViewModel? value) partial void OnSelectedVersionViewModelChanged(ModelVersionViewModel? value)
{ {
var nsfwEnabled = settingsManager.Settings.ModelBrowserNsfwEnabled; var nsfwEnabled = settingsManager.Settings.ModelBrowserNsfwEnabled;
var allImages = value?.ModelVersion?.Images?.Where( var allImages = value
img => nsfwEnabled || img.Nsfw == "None")?.Select(x => new ImageSource(x.Url)).ToList(); ?.ModelVersion?.Images?.Where(img => nsfwEnabled || img.Nsfw == "None")
?.Select(x => new ImageSource(x.Url))
.ToList();
if (allImages == null || !allImages.Any()) if (allImages == null || !allImages.Any())
{ {
allImages = new List<ImageSource> {new(Assets.NoImage)}; allImages = new List<ImageSource> { new(Assets.NoImage) };
CanGoToNextImage = false; CanGoToNextImage = false;
} }
else else
@ -93,14 +110,16 @@ public partial class SelectModelVersionViewModel : ContentDialogViewModelBase
public void PreviousImage() public void PreviousImage()
{ {
if (SelectedImageIndex > 0) SelectedImageIndex--; if (SelectedImageIndex > 0)
SelectedImageIndex--;
CanGoToPreviousImage = SelectedImageIndex > 0; CanGoToPreviousImage = SelectedImageIndex > 0;
CanGoToNextImage = SelectedImageIndex < ImageUrls.Count - 1; CanGoToNextImage = SelectedImageIndex < ImageUrls.Count - 1;
} }
public void NextImage() public void NextImage()
{ {
if (SelectedImageIndex < ImageUrls.Count - 1) SelectedImageIndex++; if (SelectedImageIndex < ImageUrls.Count - 1)
SelectedImageIndex++;
CanGoToPreviousImage = SelectedImageIndex > 0; CanGoToPreviousImage = SelectedImageIndex > 0;
CanGoToNextImage = SelectedImageIndex < ImageUrls.Count - 1; CanGoToNextImage = SelectedImageIndex < ImageUrls.Count - 1;
} }

64
StabilityMatrix.Avalonia/ViewModels/Dialogs/UpdateViewModel.cs

@ -24,18 +24,29 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
private readonly IHttpClientFactory httpClientFactory; private readonly IHttpClientFactory httpClientFactory;
private readonly IUpdateHelper updateHelper; private readonly IUpdateHelper updateHelper;
[ObservableProperty] private bool isUpdateAvailable; [ObservableProperty]
[ObservableProperty] private UpdateInfo? updateInfo; private bool isUpdateAvailable;
[ObservableProperty] private string? releaseNotes; [ObservableProperty]
[ObservableProperty] private string? updateText; private UpdateInfo? updateInfo;
[ObservableProperty] private int progressValue;
[ObservableProperty] private bool showProgressBar; [ObservableProperty]
private string? releaseNotes;
[ObservableProperty]
private string? updateText;
[ObservableProperty]
private int progressValue;
[ObservableProperty]
private bool showProgressBar;
public UpdateViewModel( public UpdateViewModel(
ISettingsManager settingsManager, ISettingsManager settingsManager,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
IUpdateHelper updateHelper) IUpdateHelper updateHelper
)
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.httpClientFactory = httpClientFactory; this.httpClientFactory = httpClientFactory;
@ -48,19 +59,21 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
}; };
updateHelper.StartCheckingForUpdates().SafeFireAndForget(); updateHelper.StartCheckingForUpdates().SafeFireAndForget();
} }
public override async Task OnLoadedAsync() public override async Task OnLoadedAsync()
{ {
if (UpdateInfo is null) return; if (UpdateInfo is null)
return;
UpdateText = $"Stability Matrix v{UpdateInfo.Version} is now available! You currently have v{Compat.AppVersion}. Would you like to update now?";
UpdateText =
$"Stability Matrix v{UpdateInfo.Version} is now available! You currently have v{Compat.AppVersion}. Would you like to update now?";
var client = httpClientFactory.CreateClient(); var client = httpClientFactory.CreateClient();
var response = await client.GetAsync(UpdateInfo.ChangelogUrl); var response = await client.GetAsync(UpdateInfo.ChangelogUrl);
if (response.IsSuccessStatusCode) if (response.IsSuccessStatusCode)
{ {
ReleaseNotes = await response.Content.ReadAsStringAsync(); ReleaseNotes = await response.Content.ReadAsStringAsync();
// Formatting for new changelog format // Formatting for new changelog format
// https://keepachangelog.com/en/1.1.0/ // https://keepachangelog.com/en/1.1.0/
if (UpdateInfo.ChangelogUrl.EndsWith(".md", StringComparison.OrdinalIgnoreCase)) if (UpdateInfo.ChangelogUrl.EndsWith(".md", StringComparison.OrdinalIgnoreCase))
@ -87,20 +100,23 @@ public partial class UpdateViewModel : ContentDialogViewModelBase
{ {
return; return;
} }
ShowProgressBar = true; ShowProgressBar = true;
UpdateText = $"Downloading update v{UpdateInfo.Version}..."; UpdateText = $"Downloading update v{UpdateInfo.Version}...";
await updateHelper.DownloadUpdate(UpdateInfo, new Progress<ProgressReport>(report => await updateHelper.DownloadUpdate(
{ UpdateInfo,
ProgressValue = Convert.ToInt32(report.Percentage); new Progress<ProgressReport>(report =>
})); {
ProgressValue = Convert.ToInt32(report.Percentage);
})
);
// On unix, we need to set the executable bit // On unix, we need to set the executable bit
if (Compat.IsUnix) if (Compat.IsUnix)
{ {
File.SetUnixFileMode(UpdateHelper.ExecutablePath, (UnixFileMode) 0x755); File.SetUnixFileMode(UpdateHelper.ExecutablePath, (UnixFileMode)0x755);
} }
UpdateText = "Update complete. Restarting Stability Matrix in 3 seconds..."; UpdateText = "Update complete. Restarting Stability Matrix in 3 seconds...";
await Task.Delay(1000); await Task.Delay(1000);
UpdateText = "Update complete. Restarting Stability Matrix in 2 seconds..."; UpdateText = "Update complete. Restarting Stability Matrix in 2 seconds...";

329
StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs

@ -48,52 +48,68 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
private readonly ISharedFolders sharedFolders; private readonly ISharedFolders sharedFolders;
private readonly ServiceManager<ViewModelBase> dialogFactory; private readonly ServiceManager<ViewModelBase> dialogFactory;
protected readonly IPackageFactory PackageFactory; protected readonly IPackageFactory PackageFactory;
// Regex to match if input contains a yes/no prompt, // Regex to match if input contains a yes/no prompt,
// i.e "Y/n", "yes/no". Case insensitive. // i.e "Y/n", "yes/no". Case insensitive.
// Separated by / or |. // Separated by / or |.
[GeneratedRegex(@"y(/|\|)n|yes(/|\|)no", RegexOptions.IgnoreCase)] [GeneratedRegex(@"y(/|\|)n|yes(/|\|)no", RegexOptions.IgnoreCase)]
private static partial Regex InputYesNoRegex(); private static partial Regex InputYesNoRegex();
public override string Title => "Launch"; public override string Title => "Launch";
public override IconSource IconSource => new SymbolIconSource { Symbol = Symbol.Rocket, IsFilled = true}; public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.Rocket, IsFilled = true };
public ConsoleViewModel Console { get; } = new(); public ConsoleViewModel Console { get; } = new();
[ObservableProperty] private bool launchButtonVisibility; [ObservableProperty]
[ObservableProperty] private bool stopButtonVisibility; private bool launchButtonVisibility;
[ObservableProperty] private bool isLaunchTeachingTipsOpen;
[ObservableProperty] private bool showWebUiButton; [ObservableProperty]
private bool stopButtonVisibility;
[ObservableProperty, NotifyPropertyChangedFor(nameof(SelectedBasePackage),
nameof(SelectedPackageExtraCommands))] [ObservableProperty]
private bool isLaunchTeachingTipsOpen;
[ObservableProperty]
private bool showWebUiButton;
[
ObservableProperty,
NotifyPropertyChangedFor(nameof(SelectedBasePackage), nameof(SelectedPackageExtraCommands))
]
private InstalledPackage? selectedPackage; private InstalledPackage? selectedPackage;
[ObservableProperty] private ObservableCollection<InstalledPackage> installedPackages = new();
[ObservableProperty] private BasePackage? runningPackage; [ObservableProperty]
private ObservableCollection<InstalledPackage> installedPackages = new();
[ObservableProperty]
private BasePackage? runningPackage;
public virtual BasePackage? SelectedBasePackage => public virtual BasePackage? SelectedBasePackage =>
PackageFactory.FindPackageByName(SelectedPackage?.PackageName); PackageFactory.FindPackageByName(SelectedPackage?.PackageName);
public IEnumerable<string> SelectedPackageExtraCommands => public IEnumerable<string> SelectedPackageExtraCommands =>
SelectedBasePackage?.ExtraLaunchCommands ?? Enumerable.Empty<string>(); SelectedBasePackage?.ExtraLaunchCommands ?? Enumerable.Empty<string>();
// private bool clearingPackages; // private bool clearingPackages;
private string webUiUrl = string.Empty; private string webUiUrl = string.Empty;
// Input info-bars // Input info-bars
[ObservableProperty] private bool showManualInputPrompt; [ObservableProperty]
[ObservableProperty] private bool showConfirmInputPrompt; private bool showManualInputPrompt;
[ObservableProperty]
private bool showConfirmInputPrompt;
public LaunchPageViewModel( public LaunchPageViewModel(
ILogger<LaunchPageViewModel> logger, ILogger<LaunchPageViewModel> logger,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IPackageFactory packageFactory, IPackageFactory packageFactory,
IPyRunner pyRunner, IPyRunner pyRunner,
INotificationService notificationService, INotificationService notificationService,
ISharedFolders sharedFolders, ISharedFolders sharedFolders,
ServiceManager<ViewModelBase> dialogFactory) ServiceManager<ViewModelBase> dialogFactory
)
{ {
this.logger = logger; this.logger = logger;
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
@ -102,11 +118,13 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
this.notificationService = notificationService; this.notificationService = notificationService;
this.sharedFolders = sharedFolders; this.sharedFolders = sharedFolders;
this.dialogFactory = dialogFactory; this.dialogFactory = dialogFactory;
settingsManager.RelayPropertyFor(this, settingsManager.RelayPropertyFor(
this,
vm => vm.SelectedPackage, vm => vm.SelectedPackage,
settings => settings.ActiveInstalledPackage); settings => settings.ActiveInstalledPackage
);
EventManager.Instance.PackageLaunchRequested += OnPackageLaunchRequested; EventManager.Instance.PackageLaunchRequested += OnPackageLaunchRequested;
EventManager.Instance.OneClickInstallFinished += OnOneClickInstallFinished; EventManager.Instance.OneClickInstallFinished += OnOneClickInstallFinished;
EventManager.Instance.InstalledPackagesChanged += OnInstalledPackagesChanged; EventManager.Instance.InstalledPackagesChanged += OnInstalledPackagesChanged;
@ -130,9 +148,11 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
{ {
if (RunningPackage is not null) if (RunningPackage is not null)
{ {
notificationService.Show("A package is already running", notificationService.Show(
"A package is already running",
"Please stop the current package before launching another.", "Please stop the current package before launching another.",
NotificationType.Error); NotificationType.Error
);
return; return;
} }
@ -143,15 +163,19 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
public override void OnLoaded() public override void OnLoaded()
{ {
// Ensure active package either exists or is null // Ensure active package either exists or is null
settingsManager.Transaction(s => settingsManager.Transaction(
{ s =>
s.UpdateActiveInstalledPackage(); {
}, ignoreMissingLibraryDir: true); s.UpdateActiveInstalledPackage();
},
ignoreMissingLibraryDir: true
);
// Load installed packages // Load installed packages
InstalledPackages = InstalledPackages = new ObservableCollection<InstalledPackage>(
new ObservableCollection<InstalledPackage>(settingsManager.Settings.InstalledPackages); settingsManager.Settings.InstalledPackages
);
// Load active package // Load active package
SelectedPackage = settingsManager.Settings.ActiveInstalledPackage; SelectedPackage = settingsManager.Settings.ActiveInstalledPackage;
} }
@ -165,15 +189,18 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
protected virtual async Task LaunchImpl(string? command) protected virtual async Task LaunchImpl(string? command)
{ {
IsLaunchTeachingTipsOpen = false; IsLaunchTeachingTipsOpen = false;
var activeInstall = SelectedPackage; var activeInstall = SelectedPackage;
if (activeInstall == null) if (activeInstall == null)
{ {
// No selected package: error notification // No selected package: error notification
notificationService.Show(new Notification( notificationService.Show(
message: "You must install and select a package before launching", new Notification(
title: "No package selected", message: "You must install and select a package before launching",
type: NotificationType.Error)); title: "No package selected",
type: NotificationType.Error
)
);
return; return;
} }
@ -186,11 +213,16 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
{ {
logger.LogWarning( logger.LogWarning(
"During launch, package name '{PackageName}' did not match a definition", "During launch, package name '{PackageName}' did not match a definition",
activeInstallName); activeInstallName
);
notificationService.Show(new Notification("Package name invalid",
"Install package name did not match a definition. Please reinstall and let us know about this issue.", notificationService.Show(
NotificationType.Error)); new Notification(
"Package name invalid",
"Install package name did not match a definition. Please reinstall and let us know about this issue.",
NotificationType.Error
)
);
return; return;
} }
@ -205,13 +237,13 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
.FromDefinitions(definitions, Array.Empty<LaunchOption>()) .FromDefinitions(definitions, Array.Empty<LaunchOption>())
.ToImmutableArray(); .ToImmutableArray();
var args = cards var args = cards.SelectMany(c => c.Options).ToList();
.SelectMany(c => c.Options)
.ToList(); logger.LogDebug(
"Setting initial launch args: {Args}",
logger.LogDebug("Setting initial launch args: {Args}", string.Join(", ", args.Select(o => o.ToArgString()?.ToRepr()))
string.Join(", ", args.Select(o => o.ToArgString()?.ToRepr()))); );
settingsManager.SaveLaunchArgs(activeInstall.Id, args); settingsManager.SaveLaunchArgs(activeInstall.Id, args);
} }
@ -219,22 +251,24 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
// Get path from package // Get path from package
var packagePath = new DirectoryPath(settingsManager.LibraryDir, activeInstall.LibraryPath!); var packagePath = new DirectoryPath(settingsManager.LibraryDir, activeInstall.LibraryPath!);
// Unpack sitecustomize.py to venv // Unpack sitecustomize.py to venv
await UnpackSiteCustomize(packagePath.JoinDir("venv")); await UnpackSiteCustomize(packagePath.JoinDir("venv"));
basePackage.ConsoleOutput += OnProcessOutputReceived; basePackage.ConsoleOutput += OnProcessOutputReceived;
basePackage.Exited += OnProcessExited; basePackage.Exited += OnProcessExited;
basePackage.StartupComplete += RunningPackageOnStartupComplete; basePackage.StartupComplete += RunningPackageOnStartupComplete;
// Clear console and start update processing // Clear console and start update processing
await Console.StopUpdatesAsync(); await Console.StopUpdatesAsync();
await Console.Clear(); await Console.Clear();
Console.StartUpdates(); Console.StartUpdates();
// Update shared folder links (in case library paths changed) // Update shared folder links (in case library paths changed)
await basePackage.UpdateModelFolders(packagePath, await basePackage.UpdateModelFolders(
activeInstall.PreferredSharedFolderMethod ?? basePackage.RecommendedSharedFolderMethod); packagePath,
activeInstall.PreferredSharedFolderMethod ?? basePackage.RecommendedSharedFolderMethod
);
// Load user launch args from settings and convert to string // Load user launch args from settings and convert to string
var userArgs = settingsManager.GetLaunchArgs(activeInstall.Id); var userArgs = settingsManager.GetLaunchArgs(activeInstall.Id);
@ -242,14 +276,16 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
// Join with extras, if any // Join with extras, if any
userArgsString = string.Join(" ", userArgsString, basePackage.ExtraLaunchArguments); userArgsString = string.Join(" ", userArgsString, basePackage.ExtraLaunchArguments);
// Use input command if provided, otherwise use package launch command // Use input command if provided, otherwise use package launch command
command ??= basePackage.LaunchCommand; command ??= basePackage.LaunchCommand;
await basePackage.RunPackage(packagePath, command, userArgsString); await basePackage.RunPackage(packagePath, command, userArgsString);
RunningPackage = basePackage; RunningPackage = basePackage;
EventManager.Instance.OnRunningPackageStatusChanged(new PackagePair(activeInstall, basePackage)); EventManager.Instance.OnRunningPackageStatusChanged(
new PackagePair(activeInstall, basePackage)
);
} }
// Unpacks sitecustomize.py to the target venv // Unpacks sitecustomize.py to the target venv
@ -293,12 +329,12 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
// Open a config page // Open a config page
var userLaunchArgs = settingsManager.GetLaunchArgs(activeInstall.Id); var userLaunchArgs = settingsManager.GetLaunchArgs(activeInstall.Id);
var viewModel = dialogFactory.Get<LaunchOptionsViewModel>(); var viewModel = dialogFactory.Get<LaunchOptionsViewModel>();
viewModel.Cards = LaunchOptionCard.FromDefinitions(definitions, userLaunchArgs) viewModel.Cards = LaunchOptionCard
.FromDefinitions(definitions, userLaunchArgs)
.ToImmutableArray(); .ToImmutableArray();
logger.LogDebug("Launching config dialog with cards: {CardsCount}", logger.LogDebug("Launching config dialog with cards: {CardsCount}", viewModel.Cards.Count);
viewModel.Cards.Count);
var dialog = new BetterContentDialog var dialog = new BetterContentDialog
{ {
ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled, ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled,
@ -309,12 +345,9 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
DefaultButton = ContentDialogButton.Primary, DefaultButton = ContentDialogButton.Primary,
ContentMargin = new Thickness(32, 16), ContentMargin = new Thickness(32, 16),
Padding = new Thickness(0, 16), Padding = new Thickness(0, 16),
Content = new LaunchOptionsDialog Content = new LaunchOptionsDialog { DataContext = viewModel, }
{
DataContext = viewModel,
}
}; };
var result = await dialog.ShowAsync(); var result = await dialog.ShowAsync();
if (result == ContentDialogResult.Primary) if (result == ContentDialogResult.Primary)
@ -324,7 +357,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
settingsManager.SaveLaunchArgs(activeInstall.Id, args); settingsManager.SaveLaunchArgs(activeInstall.Id, args);
} }
} }
// Send user input to running package // Send user input to running package
public async Task SendInput(string input) public async Task SendInput(string input)
{ {
@ -362,7 +395,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
ShowConfirmInputPrompt = false; ShowConfirmInputPrompt = false;
} }
[RelayCommand] [RelayCommand]
private async Task SendManualInput(string input) private async Task SendManualInput(string input)
{ {
@ -370,71 +403,85 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
Console.PostLine(input); Console.PostLine(input);
await SendInput(input); await SendInput(input);
} }
public virtual async Task Stop() public virtual async Task Stop()
{ {
if (RunningPackage is null) return; if (RunningPackage is null)
return;
await RunningPackage.WaitForShutdown(); await RunningPackage.WaitForShutdown();
RunningPackage = null; RunningPackage = null;
ShowWebUiButton = false; ShowWebUiButton = false;
Console.PostLine($"{Environment.NewLine}Stopped process at {DateTimeOffset.Now}"); Console.PostLine($"{Environment.NewLine}Stopped process at {DateTimeOffset.Now}");
} }
public void OpenWebUi() public void OpenWebUi()
{ {
if (string.IsNullOrEmpty(webUiUrl)) return; if (string.IsNullOrEmpty(webUiUrl))
return;
notificationService.TryAsync(Task.Run(() => ProcessRunner.OpenUrl(webUiUrl)),
"Failed to open URL", $"{webUiUrl}"); notificationService.TryAsync(
Task.Run(() => ProcessRunner.OpenUrl(webUiUrl)),
"Failed to open URL",
$"{webUiUrl}"
);
} }
private void OnProcessExited(object? sender, int exitCode) private void OnProcessExited(object? sender, int exitCode)
{ {
EventManager.Instance.OnRunningPackageStatusChanged(null); EventManager.Instance.OnRunningPackageStatusChanged(null);
Dispatcher.UIThread.InvokeAsync(async () => Dispatcher.UIThread
{ .InvokeAsync(async () =>
logger.LogTrace("Process exited ({Code}) at {Time:g}",
exitCode, DateTimeOffset.Now);
// Need to wait for streams to finish before detaching handlers
if (sender is BaseGitPackage {VenvRunner: not null} package)
{ {
var process = package.VenvRunner.Process; logger.LogTrace(
if (process is not null) "Process exited ({Code}) at {Time:g}",
exitCode,
DateTimeOffset.Now
);
// Need to wait for streams to finish before detaching handlers
if (sender is BaseGitPackage { VenvRunner: not null } package)
{ {
// Max 5 seconds var process = package.VenvRunner.Process;
var ct = new CancellationTokenSource(5000).Token; if (process is not null)
try
{ {
await process.WaitUntilOutputEOF(ct); // Max 5 seconds
} var ct = new CancellationTokenSource(5000).Token;
catch (OperationCanceledException e) try
{ {
logger.LogWarning("Waiting for process EOF timed out: {Message}", e.Message); await process.WaitUntilOutputEOF(ct);
}
catch (OperationCanceledException e)
{
logger.LogWarning(
"Waiting for process EOF timed out: {Message}",
e.Message
);
}
} }
} }
}
// Detach handlers
// Detach handlers if (sender is BasePackage basePackage)
if (sender is BasePackage basePackage) {
{ basePackage.ConsoleOutput -= OnProcessOutputReceived;
basePackage.ConsoleOutput -= OnProcessOutputReceived; basePackage.Exited -= OnProcessExited;
basePackage.Exited -= OnProcessExited; basePackage.StartupComplete -= RunningPackageOnStartupComplete;
basePackage.StartupComplete -= RunningPackageOnStartupComplete; }
} RunningPackage = null;
RunningPackage = null; ShowWebUiButton = false;
ShowWebUiButton = false;
await Console.StopUpdatesAsync();
await Console.StopUpdatesAsync();
// Need to reset cursor in case its in some weird position
// Need to reset cursor in case its in some weird position // from progress bars
// from progress bars await Console.ResetWriteCursor();
await Console.ResetWriteCursor(); Console.PostLine(
Console.PostLine($"{Environment.NewLine}Process finished with exit code {exitCode}"); $"{Environment.NewLine}Process finished with exit code {exitCode}"
);
}).SafeFireAndForget(); })
.SafeFireAndForget();
} }
// Callback for processes // Callback for processes
@ -448,7 +495,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
{ {
OnLoaded(); OnLoaded();
} }
private void RunningPackageOnStartupComplete(object? sender, string e) private void RunningPackageOnStartupComplete(object? sender, string e)
{ {
webUiUrl = e.Replace("0.0.0.0", "127.0.0.1"); webUiUrl = e.Replace("0.0.0.0", "127.0.0.1");
@ -463,39 +510,45 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
if (e.CloseReason is WindowCloseReason.WindowClosing) if (e.CloseReason is WindowCloseReason.WindowClosing)
{ {
e.Cancel = true; e.Cancel = true;
var dialog = CreateExitConfirmDialog(); var dialog = CreateExitConfirmDialog();
Dispatcher.UIThread.InvokeAsync(async () => Dispatcher.UIThread
{ .InvokeAsync(async () =>
if ((TaskDialogStandardResult)
await dialog.ShowAsync(true) == TaskDialogStandardResult.Yes)
{ {
App.Services.GetRequiredService<MainWindow>().Hide(); if (
App.Shutdown(); (TaskDialogStandardResult)await dialog.ShowAsync(true)
} == TaskDialogStandardResult.Yes
}).SafeFireAndForget(); )
{
App.Services.GetRequiredService<MainWindow>().Hide();
App.Shutdown();
}
})
.SafeFireAndForget();
} }
} }
} }
private static TaskDialog CreateExitConfirmDialog() private static TaskDialog CreateExitConfirmDialog()
{ {
var dialog = DialogHelper.CreateTaskDialog("Confirm Exit", var dialog = DialogHelper.CreateTaskDialog(
"Are you sure you want to exit? This will also close the currently running package."); "Confirm Exit",
"Are you sure you want to exit? This will also close the currently running package."
);
dialog.ShowProgressBar = false; dialog.ShowProgressBar = false;
dialog.FooterVisibility = TaskDialogFooterVisibility.Never; dialog.FooterVisibility = TaskDialogFooterVisibility.Never;
dialog.Buttons = new List<TaskDialogButton> dialog.Buttons = new List<TaskDialogButton>
{ {
new("Exit", TaskDialogStandardResult.Yes), new("Exit", TaskDialogStandardResult.Yes),
TaskDialogButton.CancelButton TaskDialogButton.CancelButton
}; };
dialog.Buttons[0].IsDefault = true; dialog.Buttons[0].IsDefault = true;
return dialog; return dialog;
} }
public void Dispose() public void Dispose()
{ {
RunningPackage?.Shutdown(); RunningPackage?.Shutdown();
@ -514,7 +567,7 @@ public partial class LaunchPageViewModel : PageViewModelBase, IDisposable, IAsyn
RunningPackage = null; RunningPackage = null;
} }
await Console.DisposeAsync(); await Console.DisposeAsync();
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
} }

182
StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs

@ -34,19 +34,29 @@ public partial class PackageCardViewModel : ProgressViewModel
private readonly INavigationService navigationService; private readonly INavigationService navigationService;
private readonly ServiceManager<ViewModelBase> vmFactory; private readonly ServiceManager<ViewModelBase> vmFactory;
[ObservableProperty] private InstalledPackage? package; [ObservableProperty]
[ObservableProperty] private string? cardImageSource; private InstalledPackage? package;
[ObservableProperty] private bool isUpdateAvailable;
[ObservableProperty] private string? installedVersion; [ObservableProperty]
[ObservableProperty] private bool isUnknownPackage; private string? cardImageSource;
[ObservableProperty]
private bool isUpdateAvailable;
[ObservableProperty]
private string? installedVersion;
[ObservableProperty]
private bool isUnknownPackage;
public PackageCardViewModel( public PackageCardViewModel(
ILogger<PackageCardViewModel> logger, ILogger<PackageCardViewModel> logger,
IPackageFactory packageFactory, IPackageFactory packageFactory,
INotificationService notificationService, INotificationService notificationService,
ISettingsManager settingsManager, ISettingsManager settingsManager,
INavigationService navigationService, INavigationService navigationService,
ServiceManager<ViewModelBase> vmFactory) ServiceManager<ViewModelBase> vmFactory
)
{ {
this.logger = logger; this.logger = logger;
this.packageFactory = packageFactory; this.packageFactory = packageFactory;
@ -70,10 +80,9 @@ public partial class PackageCardViewModel : ProgressViewModel
else else
{ {
IsUnknownPackage = false; IsUnknownPackage = false;
var basePackage = packageFactory[value.PackageName]; var basePackage = packageFactory[value.PackageName];
CardImageSource = basePackage?.PreviewImageUri.ToString() CardImageSource = basePackage?.PreviewImageUri.ToString() ?? Assets.NoImage.ToString();
?? Assets.NoImage.ToString();
InstalledVersion = value.Version?.DisplayVersion ?? "Unknown"; InstalledVersion = value.Version?.DisplayVersion ?? "Unknown";
} }
} }
@ -87,13 +96,13 @@ public partial class PackageCardViewModel : ProgressViewModel
{ {
if (Package == null) if (Package == null)
return; return;
settingsManager.Transaction(s => s.ActiveInstalledPackageId = Package.Id); settingsManager.Transaction(s => s.ActiveInstalledPackageId = Package.Id);
navigationService.NavigateTo<LaunchPageViewModel>(new BetterDrillInNavigationTransition()); navigationService.NavigateTo<LaunchPageViewModel>(new BetterDrillInNavigationTransition());
EventManager.Instance.OnPackageLaunchRequested(Package.Id); EventManager.Instance.OnPackageLaunchRequested(Package.Id);
} }
public async Task Uninstall() public async Task Uninstall()
{ {
if (Package?.LibraryPath == null) if (Package?.LibraryPath == null)
@ -104,7 +113,8 @@ public partial class PackageCardViewModel : ProgressViewModel
var dialog = new ContentDialog var dialog = new ContentDialog
{ {
Title = "Are you sure?", Title = "Are you sure?",
Content = "This will delete all folders in the package directory, including any generated images in that directory as well as any files you may have added.", Content =
"This will delete all folders in the package directory, including any generated images in that directory as well as any files you may have added.",
PrimaryButtonText = "Yes, delete it", PrimaryButtonText = "Yes, delete it",
CloseButtonText = "No, keep it", CloseButtonText = "No, keep it",
DefaultButton = ContentDialogButton.Primary DefaultButton = ContentDialogButton.Primary
@ -119,14 +129,20 @@ public partial class PackageCardViewModel : ProgressViewModel
var packagePath = new DirectoryPath(settingsManager.LibraryDir, Package.LibraryPath); var packagePath = new DirectoryPath(settingsManager.LibraryDir, Package.LibraryPath);
var deleteTask = packagePath.DeleteVerboseAsync(logger); var deleteTask = packagePath.DeleteVerboseAsync(logger);
var taskResult = await notificationService.TryAsync(deleteTask, var taskResult = await notificationService.TryAsync(
"Some files could not be deleted. Please close any open files in the package directory and try again."); deleteTask,
"Some files could not be deleted. Please close any open files in the package directory and try again."
);
if (taskResult.IsSuccessful) if (taskResult.IsSuccessful)
{ {
notificationService.Show(new Notification("Success", notificationService.Show(
$"Package {Package.DisplayName} uninstalled", new Notification(
NotificationType.Success)); "Success",
$"Package {Package.DisplayName} uninstalled",
NotificationType.Success
)
);
if (!IsUnknownPackage) if (!IsUnknownPackage)
{ {
@ -135,62 +151,74 @@ public partial class PackageCardViewModel : ProgressViewModel
settings.RemoveInstalledPackageAndUpdateActive(Package); settings.RemoveInstalledPackageAndUpdateActive(Package);
}); });
} }
EventManager.Instance.OnInstalledPackagesChanged(); EventManager.Instance.OnInstalledPackagesChanged();
} }
} }
} }
public async Task Update() public async Task Update()
{ {
if (Package is null || IsUnknownPackage) return; if (Package is null || IsUnknownPackage)
return;
var basePackage = packageFactory[Package.PackageName!]; var basePackage = packageFactory[Package.PackageName!];
if (basePackage == null) if (basePackage == null)
{ {
logger.LogWarning("Could not find package {SelectedPackagePackageName}", logger.LogWarning(
Package.PackageName); "Could not find package {SelectedPackagePackageName}",
notificationService.Show("Invalid Package type", Package.PackageName
);
notificationService.Show(
"Invalid Package type",
$"Package {Package.PackageName.ToRepr()} is not a valid package type", $"Package {Package.PackageName.ToRepr()} is not a valid package type",
NotificationType.Error); NotificationType.Error
);
return; return;
} }
var packageName = Package.DisplayName ?? Package.PackageName ?? ""; var packageName = Package.DisplayName ?? Package.PackageName ?? "";
Text = $"Updating {packageName}"; Text = $"Updating {packageName}";
IsIndeterminate = true; IsIndeterminate = true;
var progressId = Guid.NewGuid(); var progressId = Guid.NewGuid();
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId, EventManager.Instance.OnProgressChanged(
Package.DisplayName ?? Package.PackageName!, new ProgressItem(
new ProgressReport(0f, isIndeterminate: true, type: ProgressType.Update))); progressId,
Package.DisplayName ?? Package.PackageName!,
new ProgressReport(0f, isIndeterminate: true, type: ProgressType.Update)
)
);
try try
{ {
var progress = new Progress<ProgressReport>(progress => var progress = new Progress<ProgressReport>(progress =>
{ {
var percent = Convert.ToInt32(progress.Percentage); var percent = Convert.ToInt32(progress.Percentage);
Value = percent; Value = percent;
IsIndeterminate = progress.IsIndeterminate; IsIndeterminate = progress.IsIndeterminate;
Text = $"Updating {Package.DisplayName}"; Text = $"Updating {Package.DisplayName}";
EventManager.Instance.OnGlobalProgressChanged(percent); EventManager.Instance.OnGlobalProgressChanged(percent);
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId, EventManager.Instance.OnProgressChanged(
packageName, progress)); new ProgressItem(progressId, packageName, progress)
);
}); });
var torchVersion = Package.PreferredTorchVersion ?? var torchVersion =
basePackage.GetRecommendedTorchVersion(); Package.PreferredTorchVersion ?? basePackage.GetRecommendedTorchVersion();
var updateResult = await basePackage.Update(Package, torchVersion, progress); var updateResult = await basePackage.Update(Package, torchVersion, progress);
settingsManager.UpdatePackageVersionNumber(Package.Id, updateResult); settingsManager.UpdatePackageVersionNumber(Package.Id, updateResult);
notificationService.Show("Update complete", notificationService.Show(
"Update complete",
$"{Package.DisplayName} has been updated to the latest version.", $"{Package.DisplayName} has been updated to the latest version.",
NotificationType.Success); NotificationType.Success
);
await using (settingsManager.BeginTransaction()) await using (settingsManager.BeginTransaction())
{ {
Package.UpdateAvailable = false; Package.UpdateAvailable = false;
@ -198,17 +226,30 @@ public partial class PackageCardViewModel : ProgressViewModel
IsUpdateAvailable = false; IsUpdateAvailable = false;
InstalledVersion = updateResult.DisplayVersion ?? "Unknown"; InstalledVersion = updateResult.DisplayVersion ?? "Unknown";
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId, EventManager.Instance.OnProgressChanged(
packageName, new ProgressItem(
new ProgressReport(1f, "Update complete", type: ProgressType.Update))); progressId,
packageName,
new ProgressReport(1f, "Update complete", type: ProgressType.Update)
)
);
} }
catch (Exception e) catch (Exception e)
{ {
logger.LogError(e, "Error Updating Package ({PackageName})", basePackage.Name); logger.LogError(e, "Error Updating Package ({PackageName})", basePackage.Name);
notificationService.ShowPersistent($"Error Updating {Package.DisplayName}", e.Message, NotificationType.Error); notificationService.ShowPersistent(
EventManager.Instance.OnProgressChanged(new ProgressItem(progressId, $"Error Updating {Package.DisplayName}",
packageName, e.Message,
new ProgressReport(0f, "Update failed", type: ProgressType.Update), Failed: true)); NotificationType.Error
);
EventManager.Instance.OnProgressChanged(
new ProgressItem(
progressId,
packageName,
new ProgressReport(0f, "Update failed", type: ProgressType.Update),
Failed: true
)
);
} }
finally finally
{ {
@ -220,27 +261,23 @@ public partial class PackageCardViewModel : ProgressViewModel
public async Task Import() public async Task Import()
{ {
if (!IsUnknownPackage || Design.IsDesignMode) return; if (!IsUnknownPackage || Design.IsDesignMode)
return;
var viewModel = vmFactory.Get<PackageImportViewModel>(vm => var viewModel = vmFactory.Get<PackageImportViewModel>(vm =>
{ {
vm.PackagePath = vm.PackagePath = new DirectoryPath(
new DirectoryPath(Package?.FullPath ?? throw new InvalidOperationException()); Package?.FullPath ?? throw new InvalidOperationException()
);
}); });
var dialog = new TaskDialog var dialog = new TaskDialog
{ {
Content = new PackageImportDialog Content = new PackageImportDialog { DataContext = viewModel },
{
DataContext = viewModel
},
ShowProgressBar = false, ShowProgressBar = false,
Buttons = new List<TaskDialogButton> Buttons = new List<TaskDialogButton>
{ {
new(Resources.Action_Import, TaskDialogStandardResult.Yes) new(Resources.Action_Import, TaskDialogStandardResult.Yes) { IsDefault = true },
{
IsDefault = true
},
new(Resources.Action_Cancel, TaskDialogStandardResult.Cancel) new(Resources.Action_Cancel, TaskDialogStandardResult.Cancel)
} }
}; };
@ -257,7 +294,9 @@ public partial class PackageCardViewModel : ProgressViewModel
await using (new MinimumDelay(200, 300)) await using (new MinimumDelay(200, 300))
{ {
var result = await notificationService.TryAsync(viewModel.AddPackageWithCurrentInputs()); var result = await notificationService.TryAsync(
viewModel.AddPackageWithCurrentInputs()
);
if (result.IsSuccessful) if (result.IsSuccessful)
{ {
EventManager.Instance.OnInstalledPackagesChanged(); EventManager.Instance.OnInstalledPackagesChanged();
@ -269,18 +308,18 @@ public partial class PackageCardViewModel : ProgressViewModel
}; };
dialog.XamlRoot = App.VisualRoot; dialog.XamlRoot = App.VisualRoot;
await dialog.ShowAsync(true); await dialog.ShowAsync(true);
} }
public async Task OpenFolder() public async Task OpenFolder()
{ {
if (string.IsNullOrWhiteSpace(Package?.FullPath)) if (string.IsNullOrWhiteSpace(Package?.FullPath))
return; return;
await ProcessRunner.OpenFolderBrowser(Package.FullPath); await ProcessRunner.OpenFolderBrowser(Package.FullPath);
} }
private async Task<bool> HasUpdate() private async Task<bool> HasUpdate()
{ {
if (Package == null || IsUnknownPackage || Design.IsDesignMode) if (Package == null || IsUnknownPackage || Design.IsDesignMode)
@ -290,8 +329,9 @@ public partial class PackageCardViewModel : ProgressViewModel
if (basePackage == null) if (basePackage == null)
return false; return false;
var canCheckUpdate = Package.LastUpdateCheck == null || var canCheckUpdate =
Package.LastUpdateCheck < DateTime.Now.AddMinutes(-15); Package.LastUpdateCheck == null
|| Package.LastUpdateCheck < DateTime.Now.AddMinutes(-15);
if (!canCheckUpdate) if (!canCheckUpdate)
{ {

48
StabilityMatrix.Avalonia/ViewModels/PackageManagerViewModel.cs

@ -84,11 +84,14 @@ public partial class PackageManagerViewModel : PageViewModelBase
.Or(unknown) .Or(unknown)
.DeferUntilLoaded() .DeferUntilLoaded()
.Bind(Packages) .Bind(Packages)
.Transform(p => dialogFactory.Get<PackageCardViewModel>(vm => .Transform(
{ p =>
vm.Package = p; dialogFactory.Get<PackageCardViewModel>(vm =>
vm.OnLoadedAsync().SafeFireAndForget(); {
})) vm.Package = p;
vm.OnLoadedAsync().SafeFireAndForget();
})
)
.Bind(PackageCards) .Bind(PackageCards)
.Subscribe(); .Subscribe();
} }
@ -107,8 +110,11 @@ public partial class PackageManagerViewModel : PageViewModelBase
{ {
if (Design.IsDesignMode) if (Design.IsDesignMode)
return; return;
installedPackages.EditDiff(settingsManager.Settings.InstalledPackages, InstalledPackage.Comparer); installedPackages.EditDiff(
settingsManager.Settings.InstalledPackages,
InstalledPackage.Comparer
);
var currentUnknown = await Task.Run(IndexUnknownPackages); var currentUnknown = await Task.Run(IndexUnknownPackages);
unknownInstalledPackages.Edit(s => s.Load(currentUnknown)); unknownInstalledPackages.Edit(s => s.Load(currentUnknown));
@ -135,8 +141,11 @@ public partial class PackageManagerViewModel : PageViewModelBase
if (result == ContentDialogResult.Primary) if (result == ContentDialogResult.Primary)
{ {
var steps = viewModel.Steps; var steps = viewModel.Steps;
var packageModificationDialogViewModel = var packageModificationDialogViewModel = new PackageModificationDialogViewModel(
new PackageModificationDialogViewModel(packageModificationRunner, notificationService, steps); packageModificationRunner,
notificationService,
steps
);
dialog = new BetterContentDialog dialog = new BetterContentDialog
{ {
@ -146,7 +155,10 @@ public partial class PackageManagerViewModel : PageViewModelBase
IsPrimaryButtonEnabled = false, IsPrimaryButtonEnabled = false,
IsSecondaryButtonEnabled = false, IsSecondaryButtonEnabled = false,
IsFooterVisible = false, IsFooterVisible = false,
Content = new PackageModificationDialog {DataContext = packageModificationDialogViewModel} Content = new PackageModificationDialog
{
DataContext = packageModificationDialogViewModel
}
}; };
await dialog.ShowAsync(); await dialog.ShowAsync();
@ -156,22 +168,24 @@ public partial class PackageManagerViewModel : PageViewModelBase
} }
private IEnumerable<UnknownInstalledPackage> IndexUnknownPackages() private IEnumerable<UnknownInstalledPackage> IndexUnknownPackages()
{ {
var packageDir = new DirectoryPath(settingsManager.LibraryDir).JoinDir("Packages"); var packageDir = new DirectoryPath(settingsManager.LibraryDir).JoinDir("Packages");
if (!packageDir.Exists) if (!packageDir.Exists)
{ {
yield break; yield break;
} }
var currentPackages = settingsManager.Settings.InstalledPackages.ToImmutableArray(); var currentPackages = settingsManager.Settings.InstalledPackages.ToImmutableArray();
foreach (var subDir in packageDir.Info foreach (
.EnumerateDirectories() var subDir in packageDir.Info
.Select(info => new DirectoryPath(info))) .EnumerateDirectories()
.Select(info => new DirectoryPath(info))
)
{ {
var expectedLibraryPath = $"Packages{Path.DirectorySeparatorChar}{subDir.Name}"; var expectedLibraryPath = $"Packages{Path.DirectorySeparatorChar}{subDir.Name}";
// Skip if the package is already installed // Skip if the package is already installed
if (currentPackages.Any(p => p.LibraryPath == expectedLibraryPath)) if (currentPackages.Any(p => p.LibraryPath == expectedLibraryPath))
{ {

322
StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs

@ -43,7 +43,7 @@ namespace StabilityMatrix.Avalonia.ViewModels;
public partial class SettingsViewModel : PageViewModelBase public partial class SettingsViewModel : PageViewModelBase
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly INotificationService notificationService; private readonly INotificationService notificationService;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly IPrerequisiteHelper prerequisiteHelper; private readonly IPrerequisiteHelper prerequisiteHelper;
@ -51,72 +51,74 @@ public partial class SettingsViewModel : PageViewModelBase
private readonly ServiceManager<ViewModelBase> dialogFactory; private readonly ServiceManager<ViewModelBase> dialogFactory;
private readonly ITrackedDownloadService trackedDownloadService; private readonly ITrackedDownloadService trackedDownloadService;
private readonly IModelIndexService modelIndexService; private readonly IModelIndexService modelIndexService;
public SharedState SharedState { get; } public SharedState SharedState { get; }
public override string Title => "Settings"; public override string Title => "Settings";
public override IconSource IconSource => new SymbolIconSource {Symbol = Symbol.Settings, IsFilled = true}; public override IconSource IconSource =>
new SymbolIconSource { Symbol = Symbol.Settings, IsFilled = true };
// ReSharper disable once MemberCanBeMadeStatic.Global // ReSharper disable once MemberCanBeMadeStatic.Global
public string AppVersion => $"Version {Compat.AppVersion}" + public string AppVersion =>
(Program.IsDebugBuild ? " (Debug)" : ""); $"Version {Compat.AppVersion}" + (Program.IsDebugBuild ? " (Debug)" : "");
// Theme section // Theme section
[ObservableProperty] private string? selectedTheme; [ObservableProperty]
private string? selectedTheme;
public IReadOnlyList<string> AvailableThemes { get; } = new[]
{ public IReadOnlyList<string> AvailableThemes { get; } = new[] { "Light", "Dark", "System", };
"Light",
"Dark",
"System",
};
[ObservableProperty] private CultureInfo selectedLanguage; [ObservableProperty]
private CultureInfo selectedLanguage;
// ReSharper disable once MemberCanBeMadeStatic.Global // ReSharper disable once MemberCanBeMadeStatic.Global
public IReadOnlyList<CultureInfo> AvailableLanguages => Cultures.SupportedCultures; public IReadOnlyList<CultureInfo> AvailableLanguages => Cultures.SupportedCultures;
public IReadOnlyList<float> AnimationScaleOptions { get; } = new[] public IReadOnlyList<float> AnimationScaleOptions { get; } =
{ new[] { 0f, 0.25f, 0.5f, 0.75f, 1f, 1.25f, 1.5f, 1.75f, 2f, };
0f,
0.25f, [ObservableProperty]
0.5f, private float selectedAnimationScale;
0.75f,
1f,
1.25f,
1.5f,
1.75f,
2f,
};
[ObservableProperty] private float selectedAnimationScale;
// Shared folder options // Shared folder options
[ObservableProperty] private bool removeSymlinksOnShutdown; [ObservableProperty]
private bool removeSymlinksOnShutdown;
// Integrations section // Integrations section
[ObservableProperty] private bool isDiscordRichPresenceEnabled; [ObservableProperty]
private bool isDiscordRichPresenceEnabled;
// Debug section // Debug section
[ObservableProperty] private string? debugPaths; [ObservableProperty]
[ObservableProperty] private string? debugCompatInfo; private string? debugPaths;
[ObservableProperty] private string? debugGpuInfo;
[ObservableProperty]
private string? debugCompatInfo;
[ObservableProperty]
private string? debugGpuInfo;
// Info section // Info section
private const int VersionTapCountThreshold = 7; private const int VersionTapCountThreshold = 7;
[ObservableProperty, NotifyPropertyChangedFor(nameof(VersionFlyoutText))] private int versionTapCount;
[ObservableProperty] private bool isVersionTapTeachingTipOpen; [ObservableProperty, NotifyPropertyChangedFor(nameof(VersionFlyoutText))]
public string VersionFlyoutText => $"You are {VersionTapCountThreshold - VersionTapCount} clicks away from enabling Debug options."; private int versionTapCount;
[ObservableProperty]
private bool isVersionTapTeachingTipOpen;
public string VersionFlyoutText =>
$"You are {VersionTapCountThreshold - VersionTapCount} clicks away from enabling Debug options.";
public SettingsViewModel( public SettingsViewModel(
INotificationService notificationService, INotificationService notificationService,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IPrerequisiteHelper prerequisiteHelper, IPrerequisiteHelper prerequisiteHelper,
IPyRunner pyRunner, IPyRunner pyRunner,
ServiceManager<ViewModelBase> dialogFactory, ServiceManager<ViewModelBase> dialogFactory,
SharedState sharedState, SharedState sharedState,
ITrackedDownloadService trackedDownloadService, ITrackedDownloadService trackedDownloadService,
IModelIndexService modelIndexService) IModelIndexService modelIndexService
)
{ {
this.notificationService = notificationService; this.notificationService = notificationService;
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
@ -127,29 +129,32 @@ public partial class SettingsViewModel : PageViewModelBase
this.modelIndexService = modelIndexService; this.modelIndexService = modelIndexService;
SharedState = sharedState; SharedState = sharedState;
SelectedTheme = settingsManager.Settings.Theme ?? AvailableThemes[1]; SelectedTheme = settingsManager.Settings.Theme ?? AvailableThemes[1];
SelectedLanguage = Cultures.GetSupportedCultureOrDefault(settingsManager.Settings.Language); SelectedLanguage = Cultures.GetSupportedCultureOrDefault(settingsManager.Settings.Language);
RemoveSymlinksOnShutdown = settingsManager.Settings.RemoveFolderLinksOnShutdown; RemoveSymlinksOnShutdown = settingsManager.Settings.RemoveFolderLinksOnShutdown;
SelectedAnimationScale = settingsManager.Settings.AnimationScale; SelectedAnimationScale = settingsManager.Settings.AnimationScale;
settingsManager.RelayPropertyFor(this, settingsManager.RelayPropertyFor(this, vm => vm.SelectedTheme, settings => settings.Theme);
vm => vm.SelectedTheme,
settings => settings.Theme); settingsManager.RelayPropertyFor(
this,
settingsManager.RelayPropertyFor(this,
vm => vm.IsDiscordRichPresenceEnabled, vm => vm.IsDiscordRichPresenceEnabled,
settings => settings.IsDiscordRichPresenceEnabled); settings => settings.IsDiscordRichPresenceEnabled
);
settingsManager.RelayPropertyFor(this,
settingsManager.RelayPropertyFor(
this,
vm => vm.SelectedAnimationScale, vm => vm.SelectedAnimationScale,
settings => settings.AnimationScale); settings => settings.AnimationScale
);
} }
partial void OnSelectedThemeChanged(string? value) partial void OnSelectedThemeChanged(string? value)
{ {
// In case design / tests // In case design / tests
if (Application.Current is null) return; if (Application.Current is null)
return;
// Change theme // Change theme
Application.Current.RequestedThemeVariant = value switch Application.Current.RequestedThemeVariant = value switch
{ {
@ -161,16 +166,16 @@ public partial class SettingsViewModel : PageViewModelBase
partial void OnSelectedLanguageChanged(CultureInfo? oldValue, CultureInfo newValue) partial void OnSelectedLanguageChanged(CultureInfo? oldValue, CultureInfo newValue)
{ {
if (oldValue is null || newValue.Name == Cultures.Current.Name) return; if (oldValue is null || newValue.Name == Cultures.Current.Name)
return;
// Set locale // Set locale
if (AvailableLanguages.Contains(newValue)) if (AvailableLanguages.Contains(newValue))
{ {
Logger.Info("Changing language from {Old} to {New}", Logger.Info("Changing language from {Old} to {New}", oldValue, newValue);
oldValue, newValue);
Cultures.TrySetSupportedCulture(newValue); Cultures.TrySetSupportedCulture(newValue);
settingsManager.Transaction(s => s.Language = newValue.Name); settingsManager.Transaction(s => s.Language = newValue.Name);
var dialog = new BetterContentDialog var dialog = new BetterContentDialog
{ {
Title = Resources.Label_RelaunchRequired, Title = Resources.Label_RelaunchRequired,
@ -191,11 +196,14 @@ public partial class SettingsViewModel : PageViewModelBase
} }
else else
{ {
Logger.Info("Requested invalid language change from {Old} to {New}", Logger.Info(
oldValue, newValue); "Requested invalid language change from {Old} to {New}",
oldValue,
newValue
);
} }
} }
partial void OnRemoveSymlinksOnShutdownChanged(bool value) partial void OnRemoveSymlinksOnShutdownChanged(bool value)
{ {
settingsManager.Transaction(s => s.RemoveFolderLinksOnShutdown = value); settingsManager.Transaction(s => s.RemoveFolderLinksOnShutdown = value);
@ -205,29 +213,30 @@ public partial class SettingsViewModel : PageViewModelBase
{ {
settingsManager.Transaction(s => s.InstalledModelHashes = new HashSet<string>()); settingsManager.Transaction(s => s.InstalledModelHashes = new HashSet<string>());
await Task.Run(() => settingsManager.IndexCheckpoints()); await Task.Run(() => settingsManager.IndexCheckpoints());
notificationService.Show("Checkpoint cache reset", "The checkpoint cache has been reset.", notificationService.Show(
NotificationType.Success); "Checkpoint cache reset",
"The checkpoint cache has been reset.",
NotificationType.Success
);
} }
#region Package Environment #region Package Environment
[RelayCommand] [RelayCommand]
private async Task OpenEnvVarsDialog() private async Task OpenEnvVarsDialog()
{ {
var viewModel = dialogFactory.Get<EnvVarsViewModel>(); var viewModel = dialogFactory.Get<EnvVarsViewModel>();
// Load current settings // Load current settings
var current = settingsManager.Settings.EnvironmentVariables var current =
?? new Dictionary<string, string>(); settingsManager.Settings.EnvironmentVariables ?? new Dictionary<string, string>();
viewModel.EnvVars = new ObservableCollection<EnvVarKeyPair>( viewModel.EnvVars = new ObservableCollection<EnvVarKeyPair>(
current.Select(kvp => new EnvVarKeyPair(kvp.Key, kvp.Value))); current.Select(kvp => new EnvVarKeyPair(kvp.Key, kvp.Value))
);
var dialog = new BetterContentDialog var dialog = new BetterContentDialog
{ {
Content = new EnvVarsDialog Content = new EnvVarsDialog { DataContext = viewModel },
{
DataContext = viewModel
},
PrimaryButtonText = "Save", PrimaryButtonText = "Save",
IsPrimaryButtonEnabled = true, IsPrimaryButtonEnabled = true,
CloseButtonText = "Cancel", CloseButtonText = "Cancel",
@ -275,7 +284,7 @@ public partial class SettingsViewModel : PageViewModelBase
dialog.PrimaryButtonText = "Ok"; dialog.PrimaryButtonText = "Ok";
await dialog.ShowAsync(); await dialog.ShowAsync();
} }
#endregion #endregion
#region System #region System
@ -288,27 +297,29 @@ public partial class SettingsViewModel : PageViewModelBase
{ {
if (!Compat.IsWindows) if (!Compat.IsWindows)
{ {
notificationService.Show( notificationService.Show("Not supported", "This feature is only supported on Windows.");
"Not supported", "This feature is only supported on Windows.");
return; return;
} }
await using var _ = new MinimumDelay(200, 300); await using var _ = new MinimumDelay(200, 300);
var shortcutDir = new DirectoryPath( var shortcutDir = new DirectoryPath(
Environment.GetFolderPath(Environment.SpecialFolder.StartMenu), Environment.GetFolderPath(Environment.SpecialFolder.StartMenu),
"Programs"); "Programs"
);
var shortcutLink = shortcutDir.JoinFile("Stability Matrix.lnk"); var shortcutLink = shortcutDir.JoinFile("Stability Matrix.lnk");
var appPath = Compat.AppCurrentPath; var appPath = Compat.AppCurrentPath;
var iconPath = shortcutDir.JoinFile("Stability Matrix.ico"); var iconPath = shortcutDir.JoinFile("Stability Matrix.ico");
await Assets.AppIcon.ExtractTo(iconPath); await Assets.AppIcon.ExtractTo(iconPath);
WindowsShortcuts.CreateShortcut( WindowsShortcuts.CreateShortcut(shortcutLink, appPath, iconPath, "Stability Matrix");
shortcutLink, appPath, iconPath, "Stability Matrix");
notificationService.Show(
notificationService.Show("Added to Start Menu", "Added to Start Menu",
"Stability Matrix has been added to the Start Menu.", NotificationType.Success); "Stability Matrix has been added to the Start Menu.",
NotificationType.Success
);
} }
/// <summary> /// <summary>
@ -320,54 +331,62 @@ public partial class SettingsViewModel : PageViewModelBase
{ {
if (!Compat.IsWindows) if (!Compat.IsWindows)
{ {
notificationService.Show( notificationService.Show("Not supported", "This feature is only supported on Windows.");
"Not supported", "This feature is only supported on Windows.");
return; return;
} }
// Confirmation dialog // Confirmation dialog
var dialog = new BetterContentDialog var dialog = new BetterContentDialog
{ {
Title = "This will create a shortcut for Stability Matrix in the Start Menu for all users", Title =
"This will create a shortcut for Stability Matrix in the Start Menu for all users",
Content = "You will be prompted for administrator privileges. Continue?", Content = "You will be prompted for administrator privileges. Continue?",
PrimaryButtonText = "Yes", PrimaryButtonText = "Yes",
CloseButtonText = "Cancel", CloseButtonText = "Cancel",
DefaultButton = ContentDialogButton.Primary DefaultButton = ContentDialogButton.Primary
}; };
if (await dialog.ShowAsync() != ContentDialogResult.Primary) if (await dialog.ShowAsync() != ContentDialogResult.Primary)
{ {
return; return;
} }
await using var _ = new MinimumDelay(200, 300); await using var _ = new MinimumDelay(200, 300);
var shortcutDir = new DirectoryPath( var shortcutDir = new DirectoryPath(
Environment.GetFolderPath(Environment.SpecialFolder.CommonStartMenu), Environment.GetFolderPath(Environment.SpecialFolder.CommonStartMenu),
"Programs"); "Programs"
);
var shortcutLink = shortcutDir.JoinFile("Stability Matrix.lnk"); var shortcutLink = shortcutDir.JoinFile("Stability Matrix.lnk");
var appPath = Compat.AppCurrentPath; var appPath = Compat.AppCurrentPath;
var iconPath = shortcutDir.JoinFile("Stability Matrix.ico"); var iconPath = shortcutDir.JoinFile("Stability Matrix.ico");
// We can't directly write to the targets, so extract to temporary directory first // We can't directly write to the targets, so extract to temporary directory first
using var tempDir = new TempDirectoryPath(); using var tempDir = new TempDirectoryPath();
await Assets.AppIcon.ExtractTo(tempDir.JoinFile("Stability Matrix.ico")); await Assets.AppIcon.ExtractTo(tempDir.JoinFile("Stability Matrix.ico"));
WindowsShortcuts.CreateShortcut( WindowsShortcuts.CreateShortcut(
tempDir.JoinFile("Stability Matrix.lnk"), appPath, iconPath, tempDir.JoinFile("Stability Matrix.lnk"),
"Stability Matrix"); appPath,
iconPath,
"Stability Matrix"
);
// Move to target // Move to target
try try
{ {
var moveLinkResult = await WindowsElevated.MoveFiles( var moveLinkResult = await WindowsElevated.MoveFiles(
(tempDir.JoinFile("Stability Matrix.lnk"), shortcutLink), (tempDir.JoinFile("Stability Matrix.lnk"), shortcutLink),
(tempDir.JoinFile("Stability Matrix.ico"), iconPath)); (tempDir.JoinFile("Stability Matrix.ico"), iconPath)
);
if (moveLinkResult != 0) if (moveLinkResult != 0)
{ {
notificationService.ShowPersistent("Failed to create shortcut", $"Could not copy shortcut", notificationService.ShowPersistent(
NotificationType.Error); "Failed to create shortcut",
$"Could not copy shortcut",
NotificationType.Error
);
} }
} }
catch (Win32Exception e) catch (Win32Exception e)
@ -377,9 +396,12 @@ public partial class SettingsViewModel : PageViewModelBase
notificationService.Show("Could not create shortcut", "", NotificationType.Warning); notificationService.Show("Could not create shortcut", "", NotificationType.Warning);
return; return;
} }
notificationService.Show("Added to Start Menu", notificationService.Show(
"Stability Matrix has been added to the Start Menu for all users.", NotificationType.Success); "Added to Start Menu",
"Stability Matrix has been added to the Start Menu for all users.",
NotificationType.Success
);
} }
public async Task PickNewDataDirectory() public async Task PickNewDataDirectory()
@ -390,10 +412,7 @@ public partial class SettingsViewModel : PageViewModelBase
IsPrimaryButtonEnabled = false, IsPrimaryButtonEnabled = false,
IsSecondaryButtonEnabled = false, IsSecondaryButtonEnabled = false,
IsFooterVisible = false, IsFooterVisible = false,
Content = new SelectDataDirectoryDialog Content = new SelectDataDirectoryDialog { DataContext = viewModel }
{
DataContext = viewModel
}
}; };
var result = await dialog.ShowAsync(); var result = await dialog.ShowAsync();
@ -409,7 +428,7 @@ public partial class SettingsViewModel : PageViewModelBase
{ {
settingsManager.SetLibraryPath(viewModel.DataDirectory); settingsManager.SetLibraryPath(viewModel.DataDirectory);
} }
// Restart // Restart
var restartDialog = new BetterContentDialog var restartDialog = new BetterContentDialog
{ {
@ -420,14 +439,14 @@ public partial class SettingsViewModel : PageViewModelBase
IsSecondaryButtonEnabled = false, IsSecondaryButtonEnabled = false,
}; };
await restartDialog.ShowAsync(); await restartDialog.ShowAsync();
Process.Start(Compat.AppCurrentPath); Process.Start(Compat.AppCurrentPath);
App.Shutdown(); App.Shutdown();
} }
} }
#endregion #endregion
#region Debug Section #region Debug Section
public void LoadDebugInfo() public void LoadDebugInfo()
{ {
@ -443,12 +462,12 @@ public partial class SettingsViewModel : PageViewModelBase
AppData Directory [SpecialFolder.ApplicationData] AppData Directory [SpecialFolder.ApplicationData]
"{appData}" "{appData}"
"""; """;
// 1. Check portable mode // 1. Check portable mode
var appDir = Compat.AppCurrentDir; var appDir = Compat.AppCurrentDir;
var expectedPortableFile = Path.Combine(appDir, "Data", ".sm-portable"); var expectedPortableFile = Path.Combine(appDir, "Data", ".sm-portable");
var isPortableMode = File.Exists(expectedPortableFile); var isPortableMode = File.Exists(expectedPortableFile);
DebugCompatInfo = $""" DebugCompatInfo = $"""
Platform: {Compat.Platform} Platform: {Compat.Platform}
AppData: {Compat.AppData} AppData: {Compat.AppData}
@ -461,24 +480,27 @@ public partial class SettingsViewModel : PageViewModelBase
IsLibraryDirSet = {settingsManager.IsLibraryDirSet} IsLibraryDirSet = {settingsManager.IsLibraryDirSet}
IsPortableMode = {settingsManager.IsPortableMode} IsPortableMode = {settingsManager.IsPortableMode}
"""; """;
// Get Gpu info // Get Gpu info
var gpuInfo = ""; var gpuInfo = "";
foreach (var (i, gpu) in HardwareHelper.IterGpuInfo().Enumerate()) foreach (var (i, gpu) in HardwareHelper.IterGpuInfo().Enumerate())
{ {
gpuInfo += $"[{i+1}] {gpu}\n"; gpuInfo += $"[{i + 1}] {gpu}\n";
} }
DebugGpuInfo = gpuInfo; DebugGpuInfo = gpuInfo;
} }
// Debug buttons // Debug buttons
[RelayCommand] [RelayCommand]
private void DebugNotification() private void DebugNotification()
{ {
notificationService.Show(new Notification( notificationService.Show(
title: "Test Notification", new Notification(
message: "Here is some message", title: "Test Notification",
type: NotificationType.Information)); message: "Here is some message",
type: NotificationType.Information
)
);
} }
[RelayCommand] [RelayCommand]
@ -493,8 +515,7 @@ public partial class SettingsViewModel : PageViewModelBase
}; };
var result = await dialog.ShowAsync(); var result = await dialog.ShowAsync();
notificationService.Show(new Notification("Content dialog closed", notificationService.Show(new Notification("Content dialog closed", $"Result: {result}"));
$"Result: {result}"));
} }
[RelayCommand] [RelayCommand]
@ -509,20 +530,14 @@ public partial class SettingsViewModel : PageViewModelBase
{ {
await modelIndexService.RefreshIndex(); await modelIndexService.RefreshIndex();
} }
[RelayCommand] [RelayCommand]
private async Task DebugTrackedDownload() private async Task DebugTrackedDownload()
{ {
var textFields = new TextBoxField[] var textFields = new TextBoxField[]
{ {
new() new() { Label = "Url", },
{ new() { Label = "File path" }
Label = "Url",
},
new()
{
Label = "File path"
}
}; };
var dialog = DialogHelper.CreateTextEntryDialog("Add download", "", textFields); var dialog = DialogHelper.CreateTextEntryDialog("Add download", "", textFields);
@ -542,10 +557,11 @@ public partial class SettingsViewModel : PageViewModelBase
public void OnVersionClick() public void OnVersionClick()
{ {
// Ignore if already enabled // Ignore if already enabled
if (SharedState.IsDebugMode) return; if (SharedState.IsDebugMode)
return;
VersionTapCount++; VersionTapCount++;
switch (VersionTapCount) switch (VersionTapCount)
{ {
// Reached required threshold // Reached required threshold
@ -555,7 +571,9 @@ public partial class SettingsViewModel : PageViewModelBase
// Enable debug options // Enable debug options
SharedState.IsDebugMode = true; SharedState.IsDebugMode = true;
notificationService.Show( notificationService.Show(
"Debug options enabled", "Warning: Improper use may corrupt application state or cause loss of data."); "Debug options enabled",
"Warning: Improper use may corrupt application state or cause loss of data."
);
VersionTapCount = 0; VersionTapCount = 0;
break; break;
} }
@ -579,8 +597,11 @@ public partial class SettingsViewModel : PageViewModelBase
} }
catch (Exception e) catch (Exception e)
{ {
notificationService.Show("Failed to read licenses information", notificationService.Show(
$"{e}", NotificationType.Error); "Failed to read licenses information",
$"{e}",
NotificationType.Error
);
} }
} }
@ -588,15 +609,17 @@ public partial class SettingsViewModel : PageViewModelBase
{ {
// Read licenses.json // Read licenses.json
using var reader = new StreamReader(Assets.LicensesJson.Open()); using var reader = new StreamReader(Assets.LicensesJson.Open());
var licenses = JsonSerializer var licenses =
.Deserialize<IReadOnlyList<LicenseInfo>>(reader.ReadToEnd()) ?? JsonSerializer.Deserialize<IReadOnlyList<LicenseInfo>>(reader.ReadToEnd())
throw new InvalidOperationException("Failed to read licenses.json"); ?? throw new InvalidOperationException("Failed to read licenses.json");
// Generate markdown // Generate markdown
var builder = new StringBuilder(); var builder = new StringBuilder();
foreach (var license in licenses) foreach (var license in licenses)
{ {
builder.AppendLine($"## [{license.PackageName}]({license.PackageUrl}) by {string.Join(", ", license.Authors)}"); builder.AppendLine(
$"## [{license.PackageName}]({license.PackageUrl}) by {string.Join(", ", license.Authors)}"
);
builder.AppendLine(); builder.AppendLine();
builder.AppendLine(license.Description); builder.AppendLine(license.Description);
builder.AppendLine(); builder.AppendLine();
@ -608,5 +631,4 @@ public partial class SettingsViewModel : PageViewModelBase
} }
#endregion #endregion
} }

18
StabilityMatrix.Avalonia/Views/Dialogs/PackageModificationDialog.axaml.cs

@ -15,32 +15,34 @@ public partial class PackageModificationDialog : UserControlBase
public PackageModificationDialog() public PackageModificationDialog()
{ {
InitializeComponent(); InitializeComponent();
var editor = this.FindControl<TextEditor>("Console"); var editor = this.FindControl<TextEditor>("Console");
if (editor is not null) if (editor is not null)
{ {
var options = new RegistryOptions(ThemeName.DarkPlus); var options = new RegistryOptions(ThemeName.DarkPlus);
// Config hyperlinks // Config hyperlinks
editor.TextArea.Options.EnableHyperlinks = true; editor.TextArea.Options.EnableHyperlinks = true;
editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false; editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false;
editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral; editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral;
var textMate = editor.InstallTextMate(options); var textMate = editor.InstallTextMate(options);
var scope = options.GetScopeByLanguageId("log"); var scope = options.GetScopeByLanguageId("log");
if (scope is null) throw new InvalidOperationException("Scope is null"); if (scope is null)
throw new InvalidOperationException("Scope is null");
textMate.SetGrammar(scope); textMate.SetGrammar(scope);
textMate.SetTheme(options.LoadTheme(ThemeName.DarkPlus)); textMate.SetTheme(options.LoadTheme(ThemeName.DarkPlus));
} }
EventManager.Instance.ScrollToBottomRequested += (_, _) => EventManager.Instance.ScrollToBottomRequested += (_, _) =>
{ {
Dispatcher.UIThread.Invoke(() => Dispatcher.UIThread.Invoke(() =>
{ {
var editor = this.FindControl<TextEditor>("Console"); var editor = this.FindControl<TextEditor>("Console");
if (editor?.Document == null) return; if (editor?.Document == null)
return;
var line = Math.Max(editor.Document.LineCount - 5, 1); var line = Math.Max(editor.Document.LineCount - 5, 1);
editor.ScrollToLine(line); editor.ScrollToLine(line);
}); });

16
StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs

@ -14,7 +14,7 @@ namespace StabilityMatrix.Avalonia.Views;
public partial class LaunchPageView : UserControlBase public partial class LaunchPageView : UserControlBase
{ {
private const int LineOffset = 5; private const int LineOffset = 5;
public LaunchPageView() public LaunchPageView()
{ {
InitializeComponent(); InitializeComponent();
@ -22,17 +22,18 @@ public partial class LaunchPageView : UserControlBase
if (editor is not null) if (editor is not null)
{ {
var options = new RegistryOptions(ThemeName.DarkPlus); var options = new RegistryOptions(ThemeName.DarkPlus);
// Config hyperlinks // Config hyperlinks
editor.TextArea.Options.EnableHyperlinks = true; editor.TextArea.Options.EnableHyperlinks = true;
editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false; editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false;
editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral; editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral;
var textMate = editor.InstallTextMate(options); var textMate = editor.InstallTextMate(options);
var scope = options.GetScopeByLanguageId("log"); var scope = options.GetScopeByLanguageId("log");
if (scope is null) throw new InvalidOperationException("Scope is null"); if (scope is null)
throw new InvalidOperationException("Scope is null");
textMate.SetGrammar(scope); textMate.SetGrammar(scope);
textMate.SetTheme(options.LoadTheme(ThemeName.DarkPlus)); textMate.SetTheme(options.LoadTheme(ThemeName.DarkPlus));
} }
@ -55,7 +56,8 @@ public partial class LaunchPageView : UserControlBase
Dispatcher.UIThread.Invoke(() => Dispatcher.UIThread.Invoke(() =>
{ {
var editor = this.FindControl<TextEditor>("Console"); var editor = this.FindControl<TextEditor>("Console");
if (editor?.Document == null) return; if (editor?.Document == null)
return;
var line = Math.Max(editor.Document.LineCount - LineOffset, 1); var line = Math.Max(editor.Document.LineCount - LineOffset, 1);
editor.ScrollToLine(line); editor.ScrollToLine(line);
}); });

1
StabilityMatrix.Avalonia/Views/MainWindow.axaml

@ -16,6 +16,7 @@
Width="1100" Width="1100"
Height="750" Height="750"
Title="Stability Matrix" Title="Stability Matrix"
FontFamily="San Francisco, Segoe UI, Helvetica Neue, Helvetica, Arial, sans-serif"
x:Class="StabilityMatrix.Avalonia.Views.MainWindow"> x:Class="StabilityMatrix.Avalonia.Views.MainWindow">
<Grid RowDefinitions="Auto,Auto,*"> <Grid RowDefinitions="Auto,Auto,*">

3
StabilityMatrix.Core/Database/ILiteDbContext.cs

@ -7,13 +7,12 @@ namespace StabilityMatrix.Core.Database;
public interface ILiteDbContext : IDisposable public interface ILiteDbContext : IDisposable
{ {
LiteDatabaseAsync Database { get; } LiteDatabaseAsync Database { get; }
ILiteCollectionAsync<CivitModel> CivitModels { get; } ILiteCollectionAsync<CivitModel> CivitModels { get; }
ILiteCollectionAsync<CivitModelVersion> CivitModelVersions { get; } ILiteCollectionAsync<CivitModelVersion> CivitModelVersions { get; }
ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache { get; } ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache { get; }
ILiteCollectionAsync<LocalModelFile> LocalModelFiles { get; } ILiteCollectionAsync<LocalModelFile> LocalModelFiles { get; }
Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(string hashBlake3); Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(string hashBlake3);
Task<bool> UpsertCivitModelAsync(CivitModel civitModel); Task<bool> UpsertCivitModelAsync(CivitModel civitModel);
Task<bool> UpsertCivitModelAsync(IEnumerable<CivitModel> civitModels); Task<bool> UpsertCivitModelAsync(IEnumerable<CivitModel> civitModels);

115
StabilityMatrix.Core/Database/LiteDbContext.cs

@ -21,18 +21,24 @@ public class LiteDbContext : ILiteDbContext
// Notification events // Notification events
public event EventHandler? CivitModelsChanged; public event EventHandler? CivitModelsChanged;
// Collections (Tables) // Collections (Tables)
public ILiteCollectionAsync<CivitModel> CivitModels => Database.GetCollection<CivitModel>("CivitModels"); public ILiteCollectionAsync<CivitModel> CivitModels =>
public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions => Database.GetCollection<CivitModelVersion>("CivitModelVersions"); Database.GetCollection<CivitModel>("CivitModels");
public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache => Database.GetCollection<CivitModelQueryCacheEntry>("CivitModelQueryCache"); public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions =>
public ILiteCollectionAsync<GithubCacheEntry> GithubCache => Database.GetCollection<GithubCacheEntry>("GithubCache"); Database.GetCollection<CivitModelVersion>("CivitModelVersions");
public ILiteCollectionAsync<LocalModelFile> LocalModelFiles => Database.GetCollection<LocalModelFile>("LocalModelFiles"); public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache =>
Database.GetCollection<CivitModelQueryCacheEntry>("CivitModelQueryCache");
public ILiteCollectionAsync<GithubCacheEntry> GithubCache =>
Database.GetCollection<GithubCacheEntry>("GithubCache");
public ILiteCollectionAsync<LocalModelFile> LocalModelFiles =>
Database.GetCollection<LocalModelFile>("LocalModelFiles");
public LiteDbContext( public LiteDbContext(
ILogger<LiteDbContext> logger, ILogger<LiteDbContext> logger,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IOptions<DebugOptions> debugOptions) IOptions<DebugOptions> debugOptions
)
{ {
this.logger = logger; this.logger = logger;
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
@ -42,7 +48,7 @@ public class LiteDbContext : ILiteDbContext
private LiteDatabaseAsync CreateDatabase() private LiteDatabaseAsync CreateDatabase()
{ {
LiteDatabaseAsync? db = null; LiteDatabaseAsync? db = null;
if (debugOptions.TempDatabase) if (debugOptions.TempDatabase)
{ {
db = new LiteDatabaseAsync(":temp:"); db = new LiteDatabaseAsync(":temp:");
@ -53,54 +59,74 @@ public class LiteDbContext : ILiteDbContext
try try
{ {
var dbPath = Path.Combine(settingsManager.LibraryDir, "StabilityMatrix.db"); var dbPath = Path.Combine(settingsManager.LibraryDir, "StabilityMatrix.db");
db = new LiteDatabaseAsync(new ConnectionString() db = new LiteDatabaseAsync(
{ new ConnectionString()
Filename = dbPath, {
Connection = ConnectionType.Shared, Filename = dbPath,
}); Connection = ConnectionType.Shared,
}
);
} }
catch (IOException e) catch (IOException e)
{ {
logger.LogWarning("Database in use or not accessible ({Message}), using temporary database", e.Message); logger.LogWarning(
"Database in use or not accessible ({Message}), using temporary database",
e.Message
);
} }
} }
// Fallback to temporary database // Fallback to temporary database
db ??= new LiteDatabaseAsync(":temp:"); db ??= new LiteDatabaseAsync(":temp:");
// Register reference fields // Register reference fields
LiteDBExtensions.Register<CivitModel, CivitModelVersion>(m => m.ModelVersions, "CivitModelVersions"); LiteDBExtensions.Register<CivitModel, CivitModelVersion>(
LiteDBExtensions.Register<CivitModelQueryCacheEntry, CivitModel>(e => e.Items, "CivitModels"); m => m.ModelVersions,
"CivitModelVersions"
);
LiteDBExtensions.Register<CivitModelQueryCacheEntry, CivitModel>(
e => e.Items,
"CivitModels"
);
return db; return db;
} }
public async Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(string hashBlake3) public async Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(
string hashBlake3
)
{ {
var version = await CivitModelVersions.Query() var version = await CivitModelVersions
.Where(mv => mv.Files! .Query()
.Select(f => f.Hashes) .Where(
.Select(hashes => hashes.BLAKE3) mv =>
.Any(hash => hash == hashBlake3)) mv.Files!
.Select(f => f.Hashes)
.Select(hashes => hashes.BLAKE3)
.Any(hash => hash == hashBlake3)
)
.FirstOrDefaultAsync() .FirstOrDefaultAsync()
.ConfigureAwait(false); .ConfigureAwait(false);
if (version is null) return (null, null); if (version is null)
return (null, null);
var model = await CivitModels.Query()
var model = await CivitModels
.Query()
.Include(m => m.ModelVersions) .Include(m => m.ModelVersions)
.Where(m => m.ModelVersions! .Where(m => m.ModelVersions!.Select(v => v.Id).Any(id => id == version.Id))
.Select(v => v.Id) .FirstOrDefaultAsync()
.Any(id => id == version.Id)) .ConfigureAwait(false);
.FirstOrDefaultAsync().ConfigureAwait(false);
return (model, version); return (model, version);
} }
public async Task<bool> UpsertCivitModelAsync(CivitModel civitModel) public async Task<bool> UpsertCivitModelAsync(CivitModel civitModel)
{ {
// Insert model versions first then model // Insert model versions first then model
var versionsUpdated = await CivitModelVersions.UpsertAsync(civitModel.ModelVersions).ConfigureAwait(false); var versionsUpdated = await CivitModelVersions
.UpsertAsync(civitModel.ModelVersions)
.ConfigureAwait(false);
var updated = await CivitModels.UpsertAsync(civitModel).ConfigureAwait(false); var updated = await CivitModels.UpsertAsync(civitModel).ConfigureAwait(false);
// Notify listeners on any change // Notify listeners on any change
var anyUpdated = versionsUpdated > 0 || updated; var anyUpdated = versionsUpdated > 0 || updated;
@ -110,7 +136,7 @@ public class LiteDbContext : ILiteDbContext
} }
return anyUpdated; return anyUpdated;
} }
public async Task<bool> UpsertCivitModelAsync(IEnumerable<CivitModel> civitModels) public async Task<bool> UpsertCivitModelAsync(IEnumerable<CivitModel> civitModels)
{ {
var civitModelsArray = civitModels.ToArray(); var civitModelsArray = civitModels.ToArray();
@ -126,7 +152,7 @@ public class LiteDbContext : ILiteDbContext
} }
return anyUpdated; return anyUpdated;
} }
// Add to cache // Add to cache
public async Task<bool> UpsertCivitModelQueryCacheEntryAsync(CivitModelQueryCacheEntry entry) public async Task<bool> UpsertCivitModelQueryCacheEntryAsync(CivitModelQueryCacheEntry entry)
{ {
@ -141,13 +167,14 @@ public class LiteDbContext : ILiteDbContext
public async Task<GithubCacheEntry?> GetGithubCacheEntry(string? cacheKey) public async Task<GithubCacheEntry?> GetGithubCacheEntry(string? cacheKey)
{ {
if (string.IsNullOrEmpty(cacheKey)) return null; if (string.IsNullOrEmpty(cacheKey))
return null;
if (await GithubCache.FindByIdAsync(cacheKey).ConfigureAwait(false) is { } result) if (await GithubCache.FindByIdAsync(cacheKey).ConfigureAwait(false) is { } result)
{ {
return result; return result;
} }
return null; return null;
} }
@ -162,9 +189,7 @@ public class LiteDbContext : ILiteDbContext
{ {
database.Dispose(); database.Dispose();
} }
catch (ObjectDisposedException) catch (ObjectDisposedException) { }
{
}
database = null; database = null;
} }

84
StabilityMatrix.Core/Helper/SharedFolders.cs

@ -20,7 +20,7 @@ public class SharedFolders : ISharedFolders
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.packageFactory = packageFactory; this.packageFactory = packageFactory;
} }
// Platform redirect for junctions / symlinks // Platform redirect for junctions / symlinks
private static void CreateLinkOrJunction(string junctionDir, string targetDir, bool overwrite) private static void CreateLinkOrJunction(string junctionDir, string targetDir, bool overwrite)
{ {
@ -36,8 +36,11 @@ public class SharedFolders : ISharedFolders
} }
} }
public static void SetupLinks(Dictionary<SharedFolderType, IReadOnlyList<string>> definitions, public static void SetupLinks(
DirectoryPath modelsDirectory, DirectoryPath installDirectory) Dictionary<SharedFolderType, IReadOnlyList<string>> definitions,
DirectoryPath modelsDirectory,
DirectoryPath installDirectory
)
{ {
foreach (var (folderType, relativePaths) in definitions) foreach (var (folderType, relativePaths) in definitions)
{ {
@ -63,7 +66,9 @@ public class SharedFolders : ISharedFolders
// Skip name collisions // Skip name collisions
if (File.Exists(sourceFile)) if (File.Exists(sourceFile))
{ {
Logger.Warn($"Skipping file {file.FullName} because it already exists in {sourceDir}"); Logger.Warn(
$"Skipping file {file.FullName} because it already exists in {sourceDir}"
);
continue; continue;
} }
destinationFile.Info.MoveTo(sourceFile); destinationFile.Info.MoveTo(sourceFile);
@ -81,19 +86,25 @@ public class SharedFolders : ISharedFolders
{ {
var modelsDirectory = new DirectoryPath(settingsManager.ModelsDirectory); var modelsDirectory = new DirectoryPath(settingsManager.ModelsDirectory);
var sharedFolders = basePackage.SharedFolders; var sharedFolders = basePackage.SharedFolders;
if (sharedFolders == null) return; if (sharedFolders == null)
return;
SetupLinks(sharedFolders, modelsDirectory, installDirectory); SetupLinks(sharedFolders, modelsDirectory, installDirectory);
} }
/// <summary> /// <summary>
/// Deletes junction links and remakes them. Unlike SetupLinksForPackage, /// Deletes junction links and remakes them. Unlike SetupLinksForPackage,
/// this will not copy files from the destination to the source. /// this will not copy files from the destination to the source.
/// </summary> /// </summary>
public static async Task UpdateLinksForPackage(BasePackage basePackage, DirectoryPath modelsDirectory, DirectoryPath installDirectory) public static async Task UpdateLinksForPackage(
BasePackage basePackage,
DirectoryPath modelsDirectory,
DirectoryPath installDirectory
)
{ {
var sharedFolders = basePackage.SharedFolders; var sharedFolders = basePackage.SharedFolders;
if (sharedFolders is null) return; if (sharedFolders is null)
return;
foreach (var (folderType, relativePaths) in sharedFolders) foreach (var (folderType, relativePaths) in sharedFolders)
{ {
foreach (var relativePath in relativePaths) foreach (var relativePath in relativePaths)
@ -117,7 +128,8 @@ public class SharedFolders : ISharedFolders
if (destinationDir.Info.LinkTarget == sourceDir) if (destinationDir.Info.LinkTarget == sourceDir)
{ {
Logger.Info( Logger.Info(
$"Skipped updating matching folder link ({destinationDir} -> ({sourceDir})"); $"Skipped updating matching folder link ({destinationDir} -> ({sourceDir})"
);
return; return;
} }
@ -131,8 +143,12 @@ public class SharedFolders : ISharedFolders
if (destinationDir.Info.EnumerateFileSystemInfos().Any()) if (destinationDir.Info.EnumerateFileSystemInfos().Any())
{ {
Logger.Info($"Moving files from {destinationDir} to {sourceDir}"); Logger.Info($"Moving files from {destinationDir} to {sourceDir}");
await FileTransfers.MoveAllFilesAndDirectories( await FileTransfers
destinationDir, sourceDir, overwriteIfHashMatches: true) .MoveAllFilesAndDirectories(
destinationDir,
sourceDir,
overwriteIfHashMatches: true
)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
@ -154,15 +170,16 @@ public class SharedFolders : ISharedFolders
{ {
return; return;
} }
foreach (var (_, relativePaths) in sharedFolders) foreach (var (_, relativePaths) in sharedFolders)
{ {
foreach (var relativePath in relativePaths) foreach (var relativePath in relativePaths)
{ {
var destination = Path.GetFullPath(Path.Combine(installPath, relativePath)); var destination = Path.GetFullPath(Path.Combine(installPath, relativePath));
// Delete the destination folder if it exists // Delete the destination folder if it exists
if (!Directory.Exists(destination)) continue; if (!Directory.Exists(destination))
continue;
Logger.Info($"Deleting junction target {destination}"); Logger.Info($"Deleting junction target {destination}");
Directory.Delete(destination, false); Directory.Delete(destination, false);
} }
@ -174,22 +191,32 @@ public class SharedFolders : ISharedFolders
var packages = settingsManager.Settings.InstalledPackages; var packages = settingsManager.Settings.InstalledPackages;
foreach (var package in packages) foreach (var package in packages)
{ {
if (package.PackageName == null) continue; if (package.PackageName == null)
continue;
var basePackage = packageFactory[package.PackageName]; var basePackage = packageFactory[package.PackageName];
if (basePackage == null) continue; if (basePackage == null)
if (package.LibraryPath == null) continue; continue;
if (package.LibraryPath == null)
continue;
try try
{ {
var sharedFolderMethod = package.PreferredSharedFolderMethod ?? var sharedFolderMethod =
basePackage.RecommendedSharedFolderMethod; package.PreferredSharedFolderMethod
basePackage.RemoveModelFolderLinks(package.FullPath, sharedFolderMethod) ?? basePackage.RecommendedSharedFolderMethod;
.GetAwaiter().GetResult(); basePackage
.RemoveModelFolderLinks(package.FullPath, sharedFolderMethod)
.GetAwaiter()
.GetResult();
} }
catch (Exception e) catch (Exception e)
{ {
Logger.Warn("Failed to remove links for package {Package} " + Logger.Warn(
"({DisplayName}): {Message}", package.PackageName, package.DisplayName, e.Message); "Failed to remove links for package {Package} " + "({DisplayName}): {Message}",
package.PackageName,
package.DisplayName,
e.Message
);
} }
} }
} }
@ -197,8 +224,9 @@ public class SharedFolders : ISharedFolders
public void SetupSharedModelFolders() public void SetupSharedModelFolders()
{ {
var modelsDir = settingsManager.ModelsDirectory; var modelsDir = settingsManager.ModelsDirectory;
if (string.IsNullOrWhiteSpace(modelsDir)) return; if (string.IsNullOrWhiteSpace(modelsDir))
return;
Directory.CreateDirectory(modelsDir); Directory.CreateDirectory(modelsDir);
var allSharedFolderTypes = Enum.GetValues<SharedFolderType>(); var allSharedFolderTypes = Enum.GetValues<SharedFolderType>();
foreach (var sharedFolder in allSharedFolderTypes) foreach (var sharedFolder in allSharedFolderTypes)

3
StabilityMatrix.Core/Models/Database/GitCommit.cs

@ -6,5 +6,6 @@ public class GitCommit
{ {
public string? Sha { get; set; } public string? Sha { get; set; }
[JsonIgnore] public string ShortSha => string.IsNullOrWhiteSpace(Sha) ? string.Empty : Sha[..7]; [JsonIgnore]
public string ShortSha => string.IsNullOrWhiteSpace(Sha) ? string.Empty : Sha[..7];
} }

7
StabilityMatrix.Core/Models/Database/LocalModelFile.cs

@ -22,7 +22,7 @@ public class LocalModelFile
/// Optional connected model information. /// Optional connected model information.
/// </summary> /// </summary>
public ConnectedModelInfo? ConnectedModelInfo { get; set; } public ConnectedModelInfo? ConnectedModelInfo { get; set; }
/// <summary> /// <summary>
/// Optional preview image relative path. /// Optional preview image relative path.
/// </summary> /// </summary>
@ -32,10 +32,11 @@ public class LocalModelFile
{ {
return Path.Combine(rootModelDirectory, RelativePath); return Path.Combine(rootModelDirectory, RelativePath);
} }
public string? GetPreviewImageFullPath(string rootModelDirectory) public string? GetPreviewImageFullPath(string rootModelDirectory)
{ {
return PreviewImageRelativePath == null ? null return PreviewImageRelativePath == null
? null
: Path.Combine(rootModelDirectory, PreviewImageRelativePath); : Path.Combine(rootModelDirectory, PreviewImageRelativePath);
} }

75
StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs

@ -10,6 +10,7 @@ namespace StabilityMatrix.Core.Models.FileInterfaces;
public class FilePath : FileSystemPath, IPathObject public class FilePath : FileSystemPath, IPathObject
{ {
private FileInfo? _info; private FileInfo? _info;
// ReSharper disable once MemberCanBePrivate.Global // ReSharper disable once MemberCanBePrivate.Global
[JsonIgnore] [JsonIgnore]
public FileInfo Info => _info ??= new FileInfo(FullPath); public FileInfo Info => _info ??= new FileInfo(FullPath);
@ -23,17 +24,16 @@ public class FilePath : FileSystemPath, IPathObject
return Info.Attributes.HasFlag(FileAttributes.ReparsePoint); return Info.Attributes.HasFlag(FileAttributes.ReparsePoint);
} }
} }
[JsonIgnore] [JsonIgnore]
public bool Exists => Info.Exists; public bool Exists => Info.Exists;
[JsonIgnore] [JsonIgnore]
public string Name => Info.Name; public string Name => Info.Name;
[JsonIgnore] [JsonIgnore]
public string NameWithoutExtension public string NameWithoutExtension => Path.GetFileNameWithoutExtension(Info.Name);
=> Path.GetFileNameWithoutExtension(Info.Name);
/// <summary> /// <summary>
/// Get the directory of the file. /// Get the directory of the file.
/// </summary> /// </summary>
@ -44,8 +44,7 @@ public class FilePath : FileSystemPath, IPathObject
{ {
try try
{ {
return Info.Directory == null ? null return Info.Directory == null ? null : new DirectoryPath(Info.Directory);
: new DirectoryPath(Info.Directory);
} }
catch (DirectoryNotFoundException) catch (DirectoryNotFoundException)
{ {
@ -54,32 +53,31 @@ public class FilePath : FileSystemPath, IPathObject
} }
} }
public FilePath(string path) : base(path) public FilePath(string path)
{ : base(path) { }
}
public FilePath(FileInfo fileInfo)
public FilePath(FileInfo fileInfo) : base(fileInfo.FullName) : base(fileInfo.FullName)
{ {
_info = fileInfo; _info = fileInfo;
} }
public FilePath(FileSystemPath path) : base(path) public FilePath(FileSystemPath path)
{ : base(path) { }
}
public FilePath(params string[] paths)
public FilePath(params string[] paths) : base(paths) : base(paths) { }
{
}
public long GetSize() public long GetSize()
{ {
Info.Refresh(); Info.Refresh();
return Info.Length; return Info.Length;
} }
public long GetSize(bool includeSymbolicLinks) public long GetSize(bool includeSymbolicLinks)
{ {
if (!includeSymbolicLinks && IsSymbolicLink) return 0; if (!includeSymbolicLinks && IsSymbolicLink)
return 0;
return GetSize(); return GetSize();
} }
@ -87,51 +85,51 @@ public class FilePath : FileSystemPath, IPathObject
{ {
return Task.Run(() => GetSize(includeSymbolicLinks)); return Task.Run(() => GetSize(includeSymbolicLinks));
} }
/// <summary> Creates an empty file. </summary> /// <summary> Creates an empty file. </summary>
public void Create() => File.Create(FullPath).Close(); public void Create() => File.Create(FullPath).Close();
/// <summary> Deletes the file </summary> /// <summary> Deletes the file </summary>
public void Delete() => File.Delete(FullPath); public void Delete() => File.Delete(FullPath);
// Methods specific to files // Methods specific to files
/// <summary> Read text </summary> /// <summary> Read text </summary>
public string ReadAllText() => File.ReadAllText(FullPath); public string ReadAllText() => File.ReadAllText(FullPath);
/// <summary> Read text asynchronously </summary> /// <summary> Read text asynchronously </summary>
public Task<string> ReadAllTextAsync(CancellationToken ct = default) public Task<string> ReadAllTextAsync(CancellationToken ct = default)
{ {
return File.ReadAllTextAsync(FullPath, ct); return File.ReadAllTextAsync(FullPath, ct);
} }
/// <summary> Write text </summary> /// <summary> Write text </summary>
public void WriteAllText(string text) => File.WriteAllText(FullPath, text, Encoding.UTF8); public void WriteAllText(string text) => File.WriteAllText(FullPath, text, Encoding.UTF8);
/// <summary> Write text asynchronously </summary> /// <summary> Write text asynchronously </summary>
public Task WriteAllTextAsync(string text, CancellationToken ct = default) public Task WriteAllTextAsync(string text, CancellationToken ct = default)
{ {
return File.WriteAllTextAsync(FullPath, text, Encoding.UTF8, ct); return File.WriteAllTextAsync(FullPath, text, Encoding.UTF8, ct);
} }
/// <summary> Read bytes </summary> /// <summary> Read bytes </summary>
public byte[] ReadAllBytes() => File.ReadAllBytes(FullPath); public byte[] ReadAllBytes() => File.ReadAllBytes(FullPath);
/// <summary> Read bytes asynchronously </summary> /// <summary> Read bytes asynchronously </summary>
public Task<byte[]> ReadAllBytesAsync(CancellationToken ct = default) public Task<byte[]> ReadAllBytesAsync(CancellationToken ct = default)
{ {
return File.ReadAllBytesAsync(FullPath, ct); return File.ReadAllBytesAsync(FullPath, ct);
} }
/// <summary> Write bytes </summary> /// <summary> Write bytes </summary>
public void WriteAllBytes(byte[] bytes) => File.WriteAllBytes(FullPath, bytes); public void WriteAllBytes(byte[] bytes) => File.WriteAllBytes(FullPath, bytes);
/// <summary> Write bytes asynchronously </summary> /// <summary> Write bytes asynchronously </summary>
public Task WriteAllBytesAsync(byte[] bytes, CancellationToken ct = default) public Task WriteAllBytesAsync(byte[] bytes, CancellationToken ct = default)
{ {
return File.WriteAllBytesAsync(FullPath, bytes, ct); return File.WriteAllBytesAsync(FullPath, bytes, ct);
} }
/// <summary> /// <summary>
/// Move the file to a directory. /// Move the file to a directory.
/// </summary> /// </summary>
@ -141,7 +139,7 @@ public class FilePath : FileSystemPath, IPathObject
// Return the new path // Return the new path
return destinationFile; return destinationFile;
} }
/// <summary> /// <summary>
/// Move the file to a directory. /// Move the file to a directory.
/// </summary> /// </summary>
@ -151,7 +149,7 @@ public class FilePath : FileSystemPath, IPathObject
// Return the new path // Return the new path
return directory.JoinFile(this); return directory.JoinFile(this);
} }
/// <summary> /// <summary>
/// Move the file to a target path. /// Move the file to a target path.
/// </summary> /// </summary>
@ -161,7 +159,7 @@ public class FilePath : FileSystemPath, IPathObject
// Return the new path // Return the new path
return destinationFile; return destinationFile;
} }
/// <summary> /// <summary>
/// Copy the file to a target path. /// Copy the file to a target path.
/// </summary> /// </summary>
@ -174,5 +172,6 @@ public class FilePath : FileSystemPath, IPathObject
// Implicit conversions to and from string // Implicit conversions to and from string
public static implicit operator string(FilePath path) => path.FullPath; public static implicit operator string(FilePath path) => path.FullPath;
public static implicit operator FilePath(string path) => new(path); public static implicit operator FilePath(string path) => new(path);
} }

91
StabilityMatrix.Core/Models/InstalledPackage.cs

@ -10,45 +10,48 @@ public class InstalledPackage : IJsonOnDeserialized
{ {
// Unique ID for the installation // Unique ID for the installation
public Guid Id { get; set; } public Guid Id { get; set; }
// User defined name // User defined name
public string? DisplayName { get; set; } public string? DisplayName { get; set; }
// Package name // Package name
public string? PackageName { get; set; } public string? PackageName { get; set; }
// Package version // Package version
[Obsolete("Use Version instead. (Kept for migration)")] [Obsolete("Use Version instead. (Kept for migration)")]
public string? PackageVersion { get; set; } public string? PackageVersion { get; set; }
[Obsolete("Use Version instead. (Kept for migration)")] [Obsolete("Use Version instead. (Kept for migration)")]
public string? InstalledBranch { get; set; } public string? InstalledBranch { get; set; }
[Obsolete("Use Version instead. (Kept for migration)")] [Obsolete("Use Version instead. (Kept for migration)")]
public string? DisplayVersion { get; set; } public string? DisplayVersion { get; set; }
public InstalledPackageVersion? Version { get; set; } public InstalledPackageVersion? Version { get; set; }
// Old type absolute path // Old type absolute path
[Obsolete("Use LibraryPath instead. (Kept for migration)")] [Obsolete("Use LibraryPath instead. (Kept for migration)")]
public string? Path { get; set; } public string? Path { get; set; }
/// <summary> /// <summary>
/// Relative path from the library root. /// Relative path from the library root.
/// </summary> /// </summary>
public string? LibraryPath { get; set; } public string? LibraryPath { get; set; }
/// <summary> /// <summary>
/// Full path to the package, using LibraryPath and GlobalConfig.LibraryDir. /// Full path to the package, using LibraryPath and GlobalConfig.LibraryDir.
/// </summary> /// </summary>
[JsonIgnore] [JsonIgnore]
public string? FullPath => LibraryPath != null ? System.IO.Path.Combine(GlobalConfig.LibraryDir, LibraryPath) : null; public string? FullPath =>
LibraryPath != null ? System.IO.Path.Combine(GlobalConfig.LibraryDir, LibraryPath) : null;
public string? LaunchCommand { get; set; } public string? LaunchCommand { get; set; }
public List<LaunchOption>? LaunchArgs { get; set; } public List<LaunchOption>? LaunchArgs { get; set; }
public DateTimeOffset? LastUpdateCheck { get; set; } public DateTimeOffset? LastUpdateCheck { get; set; }
public bool UpdateAvailable { get; set; } public bool UpdateAvailable { get; set; }
public TorchVersion? PreferredTorchVersion { get; set; } public TorchVersion? PreferredTorchVersion { get; set; }
public SharedFolderMethod? PreferredSharedFolderMethod { get; set; } public SharedFolderMethod? PreferredSharedFolderMethod { get; set; }
/// <summary> /// <summary>
/// Get the path as a relative sub-path of the relative path. /// Get the path as a relative sub-path of the relative path.
/// If not a sub-path, return null. /// If not a sub-path, return null.
@ -57,14 +60,16 @@ public class InstalledPackage : IJsonOnDeserialized
{ {
var relativePath = System.IO.Path.GetRelativePath(relativeTo, path); var relativePath = System.IO.Path.GetRelativePath(relativeTo, path);
// GetRelativePath returns the path if it's not relative // GetRelativePath returns the path if it's not relative
if (relativePath == path) return null; if (relativePath == path)
return null;
// Further check if the path is a sub-path of the library // Further check if the path is a sub-path of the library
var isSubPath = relativePath != "." var isSubPath =
&& relativePath != ".." relativePath != "."
&& !relativePath.StartsWith(".." + System.IO.Path.DirectorySeparatorChar) && relativePath != ".."
&& !System.IO.Path.IsPathRooted(relativePath); && !relativePath.StartsWith(".." + System.IO.Path.DirectorySeparatorChar)
&& !System.IO.Path.IsPathRooted(relativePath);
return isSubPath ? relativePath : null; return isSubPath ? relativePath : null;
} }
/// <summary> /// <summary>
/// Migrates the old Path to the new LibraryPath. /// Migrates the old Path to the new LibraryPath.
@ -76,12 +81,13 @@ public class InstalledPackage : IJsonOnDeserialized
#pragma warning disable CS0618 #pragma warning disable CS0618
var oldPath = Path; var oldPath = Path;
#pragma warning restore CS0618 #pragma warning restore CS0618
if (oldPath == null) return false; if (oldPath == null)
return false;
// Check if the path is a sub-path of the library // Check if the path is a sub-path of the library
var library = libraryDirectory ?? GlobalConfig.LibraryDir; var library = libraryDirectory ?? GlobalConfig.LibraryDir;
var relativePath = GetSubPath(library, oldPath); var relativePath = GetSubPath(library, oldPath);
// If so we migrate without any IO operations // If so we migrate without any IO operations
if (relativePath != null) if (relativePath != null)
{ {
@ -105,8 +111,9 @@ public class InstalledPackage : IJsonOnDeserialized
#pragma warning disable CS0618 #pragma warning disable CS0618
var oldPath = Path; var oldPath = Path;
#pragma warning restore CS0618 #pragma warning restore CS0618
if (oldPath == null) return false; if (oldPath == null)
return false;
// Check if the path is a sub-path of the library // Check if the path is a sub-path of the library
var library = libraryDirectory ?? GlobalConfig.LibraryDir; var library = libraryDirectory ?? GlobalConfig.LibraryDir;
var relativePath = GetSubPath(library, oldPath); var relativePath = GetSubPath(library, oldPath);
@ -123,7 +130,8 @@ public class InstalledPackage : IJsonOnDeserialized
#pragma warning disable CS0618 #pragma warning disable CS0618
var oldPath = Path; var oldPath = Path;
#pragma warning restore CS0618 #pragma warning restore CS0618
if (oldPath == null) return; if (oldPath == null)
return;
var libDir = libraryDirectory ?? GlobalConfig.LibraryDir; var libDir = libraryDirectory ?? GlobalConfig.LibraryDir;
// if old package Path is same as new library, return // if old package Path is same as new library, return
@ -136,27 +144,31 @@ public class InstalledPackage : IJsonOnDeserialized
LibraryPath = System.IO.Path.Combine("Packages", DisplayName); LibraryPath = System.IO.Path.Combine("Packages", DisplayName);
return; return;
} }
// Try using pure migration first // Try using pure migration first
if (TryPureMigratePath(libraryDirectory)) return; if (TryPureMigratePath(libraryDirectory))
return;
// If not, we need to move the package directory // If not, we need to move the package directory
var packageFolderName = new DirectoryInfo(oldPath).Name; var packageFolderName = new DirectoryInfo(oldPath).Name;
// Get the new Library/Packages path // Get the new Library/Packages path
var library = libraryDirectory ?? GlobalConfig.LibraryDir; var library = libraryDirectory ?? GlobalConfig.LibraryDir;
var newPackagesDir = System.IO.Path.Combine(library, "Packages"); var newPackagesDir = System.IO.Path.Combine(library, "Packages");
// Get the new target path // Get the new target path
var newPackagePath = System.IO.Path.Combine(newPackagesDir, packageFolderName); var newPackagePath = System.IO.Path.Combine(newPackagesDir, packageFolderName);
// Ensure it is not already there, if so, add a suffix until it's not // Ensure it is not already there, if so, add a suffix until it's not
var suffix = 2; var suffix = 2;
while (Directory.Exists(newPackagePath)) while (Directory.Exists(newPackagePath))
{ {
newPackagePath = System.IO.Path.Combine(newPackagesDir, $"{packageFolderName}-{suffix}"); newPackagePath = System.IO.Path.Combine(
newPackagesDir,
$"{packageFolderName}-{suffix}"
);
suffix++; suffix++;
} }
// Move the package directory // Move the package directory
await Task.Run(() => Utilities.CopyDirectory(oldPath, newPackagePath, true)); await Task.Run(() => Utilities.CopyDirectory(oldPath, newPackagePath, true));
@ -169,7 +181,7 @@ public class InstalledPackage : IJsonOnDeserialized
public static IEqualityComparer<InstalledPackage> Comparer { get; } = public static IEqualityComparer<InstalledPackage> Comparer { get; } =
new PropertyComparer<InstalledPackage>(p => p.Id); new PropertyComparer<InstalledPackage>(p => p.Id);
protected bool Equals(InstalledPackage other) protected bool Equals(InstalledPackage other)
{ {
return Id.Equals(other.Id); return Id.Equals(other.Id);
@ -177,9 +189,11 @@ public class InstalledPackage : IJsonOnDeserialized
public override bool Equals(object? obj) public override bool Equals(object? obj)
{ {
if (ReferenceEquals(null, obj)) return false; if (ReferenceEquals(null, obj))
if (ReferenceEquals(this, obj)) return true; return false;
return obj.GetType() == this.GetType() && Equals((InstalledPackage) obj); if (ReferenceEquals(this, obj))
return true;
return obj.GetType() == this.GetType() && Equals((InstalledPackage)obj);
} }
public override int GetHashCode() public override int GetHashCode()
@ -191,16 +205,15 @@ public class InstalledPackage : IJsonOnDeserialized
public void OnDeserialized() public void OnDeserialized()
{ {
// Handle version migration // Handle version migration
if (Version != null) if (Version != null)
return; return;
if (string.IsNullOrWhiteSpace(InstalledBranch) && !string.IsNullOrWhiteSpace(PackageVersion)) if (
string.IsNullOrWhiteSpace(InstalledBranch) && !string.IsNullOrWhiteSpace(PackageVersion)
)
{ {
// release mode // release mode
Version = new InstalledPackageVersion Version = new InstalledPackageVersion { InstalledReleaseVersion = PackageVersion };
{
InstalledReleaseVersion = PackageVersion
};
} }
else if (!string.IsNullOrWhiteSpace(PackageVersion)) else if (!string.IsNullOrWhiteSpace(PackageVersion))
{ {

13
StabilityMatrix.Core/Models/InstalledPackageVersion.cs

@ -12,9 +12,12 @@ public class InstalledPackageVersion
public bool IsReleaseMode => string.IsNullOrWhiteSpace(InstalledBranch); public bool IsReleaseMode => string.IsNullOrWhiteSpace(InstalledBranch);
[JsonIgnore] [JsonIgnore]
public string DisplayVersion => (IsReleaseMode public string DisplayVersion =>
? InstalledReleaseVersion (
: string.IsNullOrWhiteSpace(InstalledCommitSha) IsReleaseMode
? InstalledBranch ? InstalledReleaseVersion
: $"{InstalledBranch}@{InstalledCommitSha[..7]}") ?? string.Empty; : string.IsNullOrWhiteSpace(InstalledCommitSha)
? InstalledBranch
: $"{InstalledBranch}@{InstalledCommitSha[..7]}"
) ?? string.Empty;
} }

6
StabilityMatrix.Core/Models/PackageModification/AddInstalledPackageStep.cs

@ -8,8 +8,10 @@ public class AddInstalledPackageStep : IPackageStep
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly InstalledPackage newInstalledPackage; private readonly InstalledPackage newInstalledPackage;
public AddInstalledPackageStep(ISettingsManager settingsManager, public AddInstalledPackageStep(
InstalledPackage newInstalledPackage) ISettingsManager settingsManager,
InstalledPackage newInstalledPackage
)
{ {
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.newInstalledPackage = newInstalledPackage; this.newInstalledPackage = newInstalledPackage;

9
StabilityMatrix.Core/Models/PackageModification/DownloadPackageVersionStep.cs

@ -9,8 +9,11 @@ public class DownloadPackageVersionStep : IPackageStep
private readonly string installPath; private readonly string installPath;
private readonly DownloadPackageVersionOptions downloadOptions; private readonly DownloadPackageVersionOptions downloadOptions;
public DownloadPackageVersionStep(BasePackage package, string installPath, public DownloadPackageVersionStep(
DownloadPackageVersionOptions downloadOptions) BasePackage package,
string installPath,
DownloadPackageVersionOptions downloadOptions
)
{ {
this.package = package; this.package = package;
this.installPath = installPath; this.installPath = installPath;
@ -19,6 +22,6 @@ public class DownloadPackageVersionStep : IPackageStep
public Task ExecuteAsync(IProgress<ProgressReport>? progress = null) => public Task ExecuteAsync(IProgress<ProgressReport>? progress = null) =>
package.DownloadPackage(installPath, downloadOptions, progress); package.DownloadPackage(installPath, downloadOptions, progress);
public string ProgressTitle => "Downloading package..."; public string ProgressTitle => "Downloading package...";
} }

2
StabilityMatrix.Core/Models/PackageModification/IPackageModificationRunner.cs

@ -9,4 +9,4 @@ public interface IPackageModificationRunner
ProgressReport CurrentProgress { get; set; } ProgressReport CurrentProgress { get; set; }
IPackageStep? CurrentStep { get; set; } IPackageStep? CurrentStep { get; set; }
event EventHandler<ProgressReport>? ProgressChanged; event EventHandler<ProgressReport>? ProgressChanged;
} }

6
StabilityMatrix.Core/Models/PackageModification/InstallPackageStep.cs

@ -20,11 +20,7 @@ public class InstallPackageStep : IPackageStep
{ {
package.ConsoleOutput += (sender, output) => package.ConsoleOutput += (sender, output) =>
{ {
progress?.Report(new ProgressReport progress?.Report(new ProgressReport { IsIndeterminate = true, Message = output.Text });
{
IsIndeterminate = true,
Message = output.Text
});
}; };
await package.InstallPackage(installPath, torchVersion, progress).ConfigureAwait(false); await package.InstallPackage(installPath, torchVersion, progress).ConfigureAwait(false);
} }

5
StabilityMatrix.Core/Models/PackageModification/PackageModificationRunner.cs

@ -11,7 +11,7 @@ public class PackageModificationRunner : IPackageModificationRunner
CurrentProgress = report; CurrentProgress = report;
OnProgressChanged(report); OnProgressChanged(report);
}); });
IsRunning = true; IsRunning = true;
foreach (var step in steps) foreach (var step in steps)
{ {
@ -25,7 +25,8 @@ public class PackageModificationRunner : IPackageModificationRunner
public bool IsRunning { get; set; } public bool IsRunning { get; set; }
public ProgressReport CurrentProgress { get; set; } public ProgressReport CurrentProgress { get; set; }
public IPackageStep? CurrentStep { get; set; } public IPackageStep? CurrentStep { get; set; }
public event EventHandler<ProgressReport>? ProgressChanged; public event EventHandler<ProgressReport>? ProgressChanged;
protected virtual void OnProgressChanged(ProgressReport e) => ProgressChanged?.Invoke(this, e); protected virtual void OnProgressChanged(ProgressReport e) => ProgressChanged?.Invoke(this, e);
} }

12
StabilityMatrix.Core/Models/PackageModification/SetupModelFoldersStep.cs

@ -9,8 +9,11 @@ public class SetupModelFoldersStep : IPackageStep
private readonly SharedFolderMethod sharedFolderMethod; private readonly SharedFolderMethod sharedFolderMethod;
private readonly string installPath; private readonly string installPath;
public SetupModelFoldersStep(BasePackage package, SharedFolderMethod sharedFolderMethod, public SetupModelFoldersStep(
string installPath) BasePackage package,
SharedFolderMethod sharedFolderMethod,
string installPath
)
{ {
this.package = package; this.package = package;
this.sharedFolderMethod = sharedFolderMethod; this.sharedFolderMethod = sharedFolderMethod;
@ -19,8 +22,9 @@ public class SetupModelFoldersStep : IPackageStep
public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null) public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null)
{ {
progress?.Report(new ProgressReport(-1f, "Setting up shared folder links...", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Setting up shared folder links...", isIndeterminate: true)
);
await package.SetupModelFolders(installPath, sharedFolderMethod).ConfigureAwait(false); await package.SetupModelFolders(installPath, sharedFolderMethod).ConfigureAwait(false);
} }

11
StabilityMatrix.Core/Models/PackageModification/SetupPrerequisitesStep.cs

@ -19,15 +19,16 @@ public class SetupPrerequisitesStep : IPackageStep
{ {
// git, vcredist, etc... // git, vcredist, etc...
await prerequisiteHelper.InstallAllIfNecessary(progress).ConfigureAwait(false); await prerequisiteHelper.InstallAllIfNecessary(progress).ConfigureAwait(false);
// python stuff // python stuff
if (!PyRunner.PipInstalled || !PyRunner.VenvInstalled) if (!PyRunner.PipInstalled || !PyRunner.VenvInstalled)
{ {
progress?.Report(new ProgressReport(-1f, "Installing Python prerequisites...", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing Python prerequisites...", isIndeterminate: true)
);
await pyRunner.Initialize().ConfigureAwait(false); await pyRunner.Initialize().ConfigureAwait(false);
if (!PyRunner.PipInstalled) if (!PyRunner.PipInstalled)
{ {
await pyRunner.SetupPip().ConfigureAwait(false); await pyRunner.SetupPip().ConfigureAwait(false);

249
StabilityMatrix.Core/Models/Packages/A3WebUI.cs

@ -14,12 +14,12 @@ namespace StabilityMatrix.Core.Models.Packages;
public class A3WebUI : BaseGitPackage public class A3WebUI : BaseGitPackage
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
public override string Name => "stable-diffusion-webui"; public override string Name => "stable-diffusion-webui";
public override string DisplayName { get; set; } = "Stable Diffusion WebUI"; public override string DisplayName { get; set; } = "Stable Diffusion WebUI";
public override string Author => "AUTOMATIC1111"; public override string Author => "AUTOMATIC1111";
public override string LicenseType => "AGPL-3.0"; public override string LicenseType => "AGPL-3.0";
public override string LicenseUrl => public override string LicenseUrl =>
"https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt"; "https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt";
public override string Blurb => public override string Blurb =>
"A browser interface based on Gradio library for Stable Diffusion"; "A browser interface based on Gradio library for Stable Diffusion";
@ -28,114 +28,112 @@ public class A3WebUI : BaseGitPackage
new("https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png"); new("https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png");
public string RelativeArgsDefinitionScriptPath => "modules.cmd_args"; public string RelativeArgsDefinitionScriptPath => "modules.cmd_args";
public override SharedFolderMethod RecommendedSharedFolderMethod => public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
SharedFolderMethod.Symlink;
public A3WebUI(IGithubApiCache githubApi, ISettingsManager settingsManager, IDownloadService downloadService, public A3WebUI(
IPrerequisiteHelper prerequisiteHelper) : IGithubApiCache githubApi,
base(githubApi, settingsManager, downloadService, prerequisiteHelper) ISettingsManager settingsManager,
{ IDownloadService downloadService,
} IPrerequisiteHelper prerequisiteHelper
)
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
// From https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/master/models // From https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/master/models
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders => new() public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
{
[SharedFolderType.StableDiffusion] = new[] {"models/Stable-diffusion"},
[SharedFolderType.ESRGAN] = new[] {"models/ESRGAN"},
[SharedFolderType.RealESRGAN] = new[] {"models/RealESRGAN"},
[SharedFolderType.SwinIR] = new[] {"models/SwinIR"},
[SharedFolderType.Lora] = new[] {"models/Lora"},
[SharedFolderType.LyCORIS] = new[] {"models/LyCORIS"},
[SharedFolderType.ApproxVAE] = new[] {"models/VAE-approx"},
[SharedFolderType.VAE] = new[] {"models/VAE"},
[SharedFolderType.DeepDanbooru] = new[] {"models/deepbooru"},
[SharedFolderType.Karlo] = new[] {"models/karlo"},
[SharedFolderType.TextualInversion] = new[] {"embeddings"},
[SharedFolderType.Hypernetwork] = new[] {"models/hypernetworks"},
[SharedFolderType.ControlNet] = new[] {"models/ControlNet"}
};
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List<LaunchOptionDefinition> LaunchOptions => new()
{
new()
{
Name = "Host",
Type = LaunchOptionType.String,
DefaultValue = "localhost",
Options = new() {"--host"}
},
new() new()
{ {
Name = "Port", [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" },
Type = LaunchOptionType.String, [SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" },
DefaultValue = "7860", [SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" },
Options = new() {"--port"} [SharedFolderType.SwinIR] = new[] { "models/SwinIR" },
}, [SharedFolderType.Lora] = new[] { "models/Lora" },
[SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" },
[SharedFolderType.ApproxVAE] = new[] { "models/VAE-approx" },
[SharedFolderType.VAE] = new[] { "models/VAE" },
[SharedFolderType.DeepDanbooru] = new[] { "models/deepbooru" },
[SharedFolderType.Karlo] = new[] { "models/karlo" },
[SharedFolderType.TextualInversion] = new[] { "embeddings" },
[SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
[SharedFolderType.ControlNet] = new[] { "models/ControlNet" }
};
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List<LaunchOptionDefinition> LaunchOptions =>
new() new()
{ {
Name = "VRAM", new()
Type = LaunchOptionType.Bool,
InitialValue = HardwareHelper.IterGpuInfo().Select(gpu => gpu.MemoryLevel).Max() switch
{ {
Level.Low => "--lowvram", Name = "Host",
Level.Medium => "--medvram", Type = LaunchOptionType.String,
_ => null DefaultValue = "localhost",
Options = new() { "--host" }
}, },
Options = new() { "--lowvram", "--medvram", "--medvram-sdxl" } new()
}, {
new() Name = "Port",
{ Type = LaunchOptionType.String,
Name = "Xformers", DefaultValue = "7860",
Type = LaunchOptionType.Bool, Options = new() { "--port" }
InitialValue = HardwareHelper.HasNvidiaGpu(), },
Options = new() { "--xformers" } new()
}, {
new() Name = "VRAM",
{ Type = LaunchOptionType.Bool,
Name = "API", InitialValue = HardwareHelper
Type = LaunchOptionType.Bool, .IterGpuInfo()
InitialValue = true, .Select(gpu => gpu.MemoryLevel)
Options = new() {"--api"} .Max() switch
}, {
new() Level.Low => "--lowvram",
{ Level.Medium => "--medvram",
Name = "Skip Torch CUDA Check", _ => null
Type = LaunchOptionType.Bool, },
InitialValue = !HardwareHelper.HasNvidiaGpu(), Options = new() { "--lowvram", "--medvram", "--medvram-sdxl" }
Options = new() {"--skip-torch-cuda-test"} },
}, new()
new() {
{ Name = "Xformers",
Name = "Skip Python Version Check", Type = LaunchOptionType.Bool,
Type = LaunchOptionType.Bool, InitialValue = HardwareHelper.HasNvidiaGpu(),
InitialValue = true, Options = new() { "--xformers" }
Options = new() {"--skip-python-version-check"} },
}, new()
new() {
{ Name = "API",
Name = "No Half", Type = LaunchOptionType.Bool,
Type = LaunchOptionType.Bool, InitialValue = true,
Description = "Do not switch the model to 16-bit floats", Options = new() { "--api" }
InitialValue = HardwareHelper.HasAmdGpu(), },
Options = new() {"--no-half"} new()
}, {
LaunchOptionDefinition.Extras Name = "Skip Torch CUDA Check",
}; Type = LaunchOptionType.Bool,
InitialValue = !HardwareHelper.HasNvidiaGpu(),
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => new[] Options = new() { "--skip-torch-cuda-test" }
{ },
SharedFolderMethod.Symlink, new()
SharedFolderMethod.None {
}; Name = "Skip Python Version Check",
Type = LaunchOptionType.Bool,
InitialValue = true,
Options = new() { "--skip-python-version-check" }
},
new()
{
Name = "No Half",
Type = LaunchOptionType.Bool,
Description = "Do not switch the model to 16-bit floats",
InitialValue = HardwareHelper.HasAmdGpu(),
Options = new() { "--no-half" }
},
LaunchOptionDefinition.Extras
};
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
{ new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
TorchVersion.Cpu,
TorchVersion.Cuda, public override IEnumerable<TorchVersion> AvailableTorchVersions =>
TorchVersion.DirectMl, new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm };
TorchVersion.Rocm
};
public override async Task<string> GetLatestVersion() public override async Task<string> GetLatestVersion()
{ {
@ -143,8 +141,11 @@ public class A3WebUI : BaseGitPackage
return release.TagName!; return release.TagName!;
} }
public override async Task InstallPackage(string installLocation, public override async Task InstallPackage(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
@ -173,14 +174,17 @@ public class A3WebUI : BaseGitPackage
} }
// Install requirements file // Install requirements file
progress?.Report(new ProgressReport(-1f, "Installing Package Requirements", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing Package Requirements", isIndeterminate: true)
);
Logger.Info("Installing requirements_versions.txt"); Logger.Info("Installing requirements_versions.txt");
await venvRunner.PipInstall($"-r requirements_versions.txt", OnConsoleOutput) await venvRunner
.PipInstall($"-r requirements_versions.txt", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
progress?.Report(new ProgressReport(1f, "Installing Package Requirements", progress?.Report(
isIndeterminate: false)); new ProgressReport(1f, "Installing Package Requirements", isIndeterminate: false)
);
progress?.Report(new ProgressReport(-1f, "Updating configuration", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Updating configuration", isIndeterminate: true));
@ -189,14 +193,18 @@ public class A3WebUI : BaseGitPackage
var configPath = Path.Combine(installLocation, "config.json"); var configPath = Path.Combine(installLocation, "config.json");
if (!File.Exists(configPath)) if (!File.Exists(configPath))
{ {
var config = new JsonObject {{"show_progress_type", "TAESD"}}; var config = new JsonObject { { "show_progress_type", "TAESD" } };
await File.WriteAllTextAsync(configPath, config.ToString()).ConfigureAwait(false); await File.WriteAllTextAsync(configPath, config.ToString()).ConfigureAwait(false);
} }
progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false)); progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false));
} }
public override async Task RunPackage(string installedPackagePath, string command, string arguments) public override async Task RunPackage(
string installedPackagePath,
string command,
string arguments
)
{ {
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
@ -204,14 +212,14 @@ public class A3WebUI : BaseGitPackage
{ {
OnConsoleOutput(s); OnConsoleOutput(s);
if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase)) if (!s.Text.Contains("Running on", StringComparison.OrdinalIgnoreCase))
return; return;
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
var match = regex.Match(s.Text); var match = regex.Match(s.Text);
if (!match.Success) if (!match.Success)
return; return;
WebUrl = match.Value; WebUrl = match.Value;
OnStartupComplete(WebUrl); OnStartupComplete(WebUrl);
} }
@ -221,16 +229,19 @@ public class A3WebUI : BaseGitPackage
VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
} }
private async Task InstallRocmTorch(PyVenvRunner venvRunner, private async Task InstallRocmTorch(
IProgress<ProgressReport>? progress = null) PyVenvRunner venvRunner,
IProgress<ProgressReport>? progress = null
)
{ {
progress?.Report(new ProgressReport(-1f, "Installing PyTorch for ROCm", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing PyTorch for ROCm", isIndeterminate: true)
);
await venvRunner.PipInstall("--upgrade pip wheel", OnConsoleOutput) await venvRunner.PipInstall("--upgrade pip wheel", OnConsoleOutput).ConfigureAwait(false);
.ConfigureAwait(false);
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, OnConsoleOutput) await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
} }

194
StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs

@ -21,13 +21,13 @@ namespace StabilityMatrix.Core.Models.Packages;
public abstract class BaseGitPackage : BasePackage public abstract class BaseGitPackage : BasePackage
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
protected readonly IGithubApiCache GithubApi; protected readonly IGithubApiCache GithubApi;
protected readonly ISettingsManager SettingsManager; protected readonly ISettingsManager SettingsManager;
protected readonly IDownloadService DownloadService; protected readonly IDownloadService DownloadService;
protected readonly IPrerequisiteHelper PrerequisiteHelper; protected readonly IPrerequisiteHelper PrerequisiteHelper;
public PyVenvRunner? VenvRunner; public PyVenvRunner? VenvRunner;
/// <summary> /// <summary>
/// URL of the hosted web page on launch /// URL of the hosted web page on launch
/// </summary> /// </summary>
@ -47,21 +47,23 @@ public abstract class BaseGitPackage : BasePackage
if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag)) if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag))
{ {
return return $"https://api.github.com/repos/{Author}/{Name}/zipball/{versionOptions.VersionTag}";
$"https://api.github.com/repos/{Author}/{Name}/zipball/{versionOptions.VersionTag}";
} }
if (!string.IsNullOrWhiteSpace(versionOptions.BranchName)) if (!string.IsNullOrWhiteSpace(versionOptions.BranchName))
{ {
return return $"https://api.github.com/repos/{Author}/{Name}/zipball/{versionOptions.BranchName}";
$"https://api.github.com/repos/{Author}/{Name}/zipball/{versionOptions.BranchName}";
} }
throw new Exception("No download URL available"); throw new Exception("No download URL available");
} }
protected BaseGitPackage(IGithubApiCache githubApi, ISettingsManager settingsManager, protected BaseGitPackage(
IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper) IGithubApiCache githubApi,
ISettingsManager settingsManager,
IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper
)
{ {
GithubApi = githubApi; GithubApi = githubApi;
SettingsManager = settingsManager; SettingsManager = settingsManager;
@ -71,51 +73,53 @@ public abstract class BaseGitPackage : BasePackage
protected async Task<Release> GetLatestRelease(bool includePrerelease = false) protected async Task<Release> GetLatestRelease(bool includePrerelease = false)
{ {
var releases = await GithubApi var releases = await GithubApi.GetAllReleases(Author, Name).ConfigureAwait(false);
.GetAllReleases(Author, Name)
.ConfigureAwait(false);
return includePrerelease ? releases.First() : releases.First(x => !x.Prerelease); return includePrerelease ? releases.First() : releases.First(x => !x.Prerelease);
} }
public override Task<IEnumerable<Branch>> GetAllBranches() public override Task<IEnumerable<Branch>> GetAllBranches()
{ {
return GithubApi.GetAllBranches(Author, Name); return GithubApi.GetAllBranches(Author, Name);
} }
public override Task<IEnumerable<Release>> GetAllReleases() public override Task<IEnumerable<Release>> GetAllReleases()
{ {
return GithubApi.GetAllReleases(Author, Name); return GithubApi.GetAllReleases(Author, Name);
} }
public override Task<IEnumerable<GitCommit>?> GetAllCommits(string branch, int page = 1, int perPage = 10) public override Task<IEnumerable<GitCommit>?> GetAllCommits(
string branch,
int page = 1,
int perPage = 10
)
{ {
return GithubApi.GetAllCommits(Author, Name, branch, page, perPage); return GithubApi.GetAllCommits(Author, Name, branch, page, perPage);
} }
public override async Task<PackageVersionOptions> GetAllVersionOptions() public override async Task<PackageVersionOptions> GetAllVersionOptions()
{ {
var packageVersionOptions = new PackageVersionOptions(); var packageVersionOptions = new PackageVersionOptions();
var allReleases = await GetAllReleases().ConfigureAwait(false); var allReleases = await GetAllReleases().ConfigureAwait(false);
var releasesList = allReleases.ToList(); var releasesList = allReleases.ToList();
if (releasesList.Any()) if (releasesList.Any())
{ {
packageVersionOptions.AvailableVersions = releasesList.Select(r => packageVersionOptions.AvailableVersions = releasesList.Select(
new PackageVersion r =>
{ new PackageVersion
TagName = r.TagName!, {
ReleaseNotesMarkdown = r.Body, TagName = r.TagName!,
IsPrerelease = r.Prerelease ReleaseNotesMarkdown = r.Body,
}); IsPrerelease = r.Prerelease
}
);
} }
// Branch mode // Branch mode
var allBranches = await GetAllBranches().ConfigureAwait(false); var allBranches = await GetAllBranches().ConfigureAwait(false);
packageVersionOptions.AvailableBranches = allBranches.Select(b => new PackageVersion packageVersionOptions.AvailableBranches = allBranches.Select(
{ b => new PackageVersion { TagName = $"{b.Name}", ReleaseNotesMarkdown = string.Empty }
TagName = $"{b.Name}", );
ReleaseNotesMarkdown = string.Empty
});
return packageVersionOptions; return packageVersionOptions;
} }
@ -125,9 +129,10 @@ public abstract class BaseGitPackage : BasePackage
/// </summary> /// </summary>
[MemberNotNull(nameof(VenvRunner))] [MemberNotNull(nameof(VenvRunner))]
public async Task<PyVenvRunner> SetupVenv( public async Task<PyVenvRunner> SetupVenv(
string installedPackagePath, string installedPackagePath,
string venvName = "venv", string venvName = "venv",
bool forceRecreate = false) bool forceRecreate = false
)
{ {
var venvPath = Path.Combine(installedPackagePath, venvName); var venvPath = Path.Combine(installedPackagePath, venvName);
if (VenvRunner != null) if (VenvRunner != null)
@ -140,24 +145,25 @@ public abstract class BaseGitPackage : BasePackage
WorkingDirectory = installedPackagePath, WorkingDirectory = installedPackagePath,
EnvironmentVariables = SettingsManager.Settings.EnvironmentVariables, EnvironmentVariables = SettingsManager.Settings.EnvironmentVariables,
}; };
if (!VenvRunner.Exists() || forceRecreate) if (!VenvRunner.Exists() || forceRecreate)
{ {
await VenvRunner.Setup(forceRecreate).ConfigureAwait(false); await VenvRunner.Setup(forceRecreate).ConfigureAwait(false);
} }
return VenvRunner; return VenvRunner;
} }
public override async Task<IEnumerable<Release>> GetReleaseTags() public override async Task<IEnumerable<Release>> GetReleaseTags()
{ {
var allReleases = await GithubApi var allReleases = await GithubApi.GetAllReleases(Author, Name).ConfigureAwait(false);
.GetAllReleases(Author, Name)
.ConfigureAwait(false);
return allReleases; return allReleases;
} }
public override async Task DownloadPackage(string installLocation, public override async Task DownloadPackage(
DownloadPackageVersionOptions versionOptions, IProgress<ProgressReport>? progress = null) string installLocation,
DownloadPackageVersionOptions versionOptions,
IProgress<ProgressReport>? progress = null
)
{ {
var downloadUrl = GetDownloadUrl(versionOptions); var downloadUrl = GetDownloadUrl(versionOptions);
@ -173,8 +179,11 @@ public abstract class BaseGitPackage : BasePackage
progress?.Report(new ProgressReport(100, message: "Download Complete")); progress?.Report(new ProgressReport(100, message: "Download Complete"));
} }
public override async Task InstallPackage(string installLocation, public override async Task InstallPackage(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
await UnzipPackage(installLocation, progress).ConfigureAwait(false); await UnzipPackage(installLocation, progress).ConfigureAwait(false);
File.Delete(DownloadLocation); File.Delete(DownloadLocation);
@ -186,7 +195,7 @@ public abstract class BaseGitPackage : BasePackage
var zipDirName = string.Empty; var zipDirName = string.Empty;
var totalEntries = zip.Entries.Count; var totalEntries = zip.Entries.Count;
var currentEntry = 0; var currentEntry = 0;
foreach (var entry in zip.Entries) foreach (var entry in zip.Entries)
{ {
currentEntry++; currentEntry++;
@ -196,20 +205,26 @@ public abstract class BaseGitPackage : BasePackage
{ {
zipDirName = entry.FullName; zipDirName = entry.FullName;
} }
var folderPath = Path.Combine(installLocation, var folderPath = Path.Combine(
entry.FullName.Replace(zipDirName, string.Empty)); installLocation,
entry.FullName.Replace(zipDirName, string.Empty)
);
Directory.CreateDirectory(folderPath); Directory.CreateDirectory(folderPath);
continue; continue;
} }
var destinationPath = Path.GetFullPath(
var destinationPath = Path.GetFullPath(Path.Combine(installLocation, Path.Combine(installLocation, entry.FullName.Replace(zipDirName, string.Empty))
entry.FullName.Replace(zipDirName, string.Empty))); );
entry.ExtractToFile(destinationPath, true); entry.ExtractToFile(destinationPath, true);
progress?.Report(new ProgressReport(current: Convert.ToUInt64(currentEntry), progress?.Report(
total: Convert.ToUInt64(totalEntries))); new ProgressReport(
current: Convert.ToUInt64(currentEntry),
total: Convert.ToUInt64(totalEntries)
)
);
} }
return Task.CompletedTask; return Task.CompletedTask;
@ -232,8 +247,9 @@ public abstract class BaseGitPackage : BasePackage
return UpdateAvailable; return UpdateAvailable;
} }
var allCommits = (await GetAllCommits(currentVersion.InstalledBranch) var allCommits = (
.ConfigureAwait(false))?.ToList(); await GetAllCommits(currentVersion.InstalledBranch).ConfigureAwait(false)
)?.ToList();
if (allCommits == null || !allCommits.Any()) if (allCommits == null || !allCommits.Any())
{ {
Logger.Warn("No commits found for {Package}", package.PackageName); Logger.Warn("No commits found for {Package}", package.PackageName);
@ -241,7 +257,6 @@ public abstract class BaseGitPackage : BasePackage
} }
var latestCommitHash = allCommits.First().Sha; var latestCommitHash = allCommits.First().Sha;
return latestCommitHash != currentVersion.InstalledCommitSha; return latestCommitHash != currentVersion.InstalledCommitSha;
} }
catch (ApiException e) catch (ApiException e)
{ {
@ -250,34 +265,37 @@ public abstract class BaseGitPackage : BasePackage
} }
} }
public override async Task<InstalledPackageVersion> Update(InstalledPackage installedPackage, public override async Task<InstalledPackageVersion> Update(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null, InstalledPackage installedPackage,
bool includePrerelease = false) TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null,
bool includePrerelease = false
)
{ {
if (installedPackage.Version == null) throw new NullReferenceException("Version is null"); if (installedPackage.Version == null)
throw new NullReferenceException("Version is null");
if (installedPackage.Version.IsReleaseMode) if (installedPackage.Version.IsReleaseMode)
{ {
var releases = await GetAllReleases().ConfigureAwait(false); var releases = await GetAllReleases().ConfigureAwait(false);
var latestRelease = releases.First(x => includePrerelease || !x.Prerelease); var latestRelease = releases.First(x => includePrerelease || !x.Prerelease);
await DownloadPackage(installedPackage.FullPath, await DownloadPackage(
new DownloadPackageVersionOptions {VersionTag = latestRelease.TagName}, installedPackage.FullPath,
progress) new DownloadPackageVersionOptions { VersionTag = latestRelease.TagName },
progress
)
.ConfigureAwait(false); .ConfigureAwait(false);
await InstallPackage(installedPackage.FullPath, torchVersion, progress) await InstallPackage(installedPackage.FullPath, torchVersion, progress)
.ConfigureAwait(false); .ConfigureAwait(false);
return new InstalledPackageVersion return new InstalledPackageVersion { InstalledReleaseVersion = latestRelease.TagName };
{
InstalledReleaseVersion = latestRelease.TagName
};
} }
// Commit mode // Commit mode
var allCommits = await GetAllCommits( var allCommits = await GetAllCommits(installedPackage.Version.InstalledBranch)
installedPackage.Version.InstalledBranch).ConfigureAwait(false); .ConfigureAwait(false);
var latestCommit = allCommits?.First(); var latestCommit = allCommits?.First();
if (latestCommit is null || string.IsNullOrEmpty(latestCommit.Sha)) if (latestCommit is null || string.IsNullOrEmpty(latestCommit.Sha))
@ -285,8 +303,11 @@ public abstract class BaseGitPackage : BasePackage
throw new Exception("No commits found for branch"); throw new Exception("No commits found for branch");
} }
await DownloadPackage(installedPackage.FullPath, await DownloadPackage(
new DownloadPackageVersionOptions {CommitHash = latestCommit.Sha}, progress) installedPackage.FullPath,
new DownloadPackageVersionOptions { CommitHash = latestCommit.Sha },
progress
)
.ConfigureAwait(false); .ConfigureAwait(false);
await InstallPackage(installedPackage.FullPath, torchVersion, progress) await InstallPackage(installedPackage.FullPath, torchVersion, progress)
.ConfigureAwait(false); .ConfigureAwait(false);
@ -298,29 +319,40 @@ public abstract class BaseGitPackage : BasePackage
}; };
} }
public override Task SetupModelFolders(DirectoryPath installDirectory, public override Task SetupModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
if (sharedFolderMethod == SharedFolderMethod.Symlink && SharedFolders is { } folders) if (sharedFolderMethod == SharedFolderMethod.Symlink && SharedFolders is { } folders)
{ {
StabilityMatrix.Core.Helper.SharedFolders StabilityMatrix.Core.Helper.SharedFolders.SetupLinks(
.SetupLinks(folders, SettingsManager.ModelsDirectory, installDirectory); folders,
SettingsManager.ModelsDirectory,
installDirectory
);
} }
return Task.CompletedTask; return Task.CompletedTask;
} }
public override async Task UpdateModelFolders(DirectoryPath installDirectory, public override async Task UpdateModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink) if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink)
{ {
await StabilityMatrix.Core.Helper.SharedFolders.UpdateLinksForPackage(this, await StabilityMatrix.Core.Helper.SharedFolders
SettingsManager.ModelsDirectory, installDirectory).ConfigureAwait(false); .UpdateLinksForPackage(this, SettingsManager.ModelsDirectory, installDirectory)
.ConfigureAwait(false);
} }
} }
public override Task RemoveModelFolderLinks(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod) public override Task RemoveModelFolderLinks(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink) if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink)
{ {
@ -340,7 +372,7 @@ public abstract class BaseGitPackage : BasePackage
} }
process.StandardInput.WriteLine(input); process.StandardInput.WriteLine(input);
} }
public virtual async Task SendInputAsync(string input) public virtual async Task SendInputAsync(string input)
{ {
var process = VenvRunner?.Process; var process = VenvRunner?.Process;
@ -361,7 +393,7 @@ public abstract class BaseGitPackage : BasePackage
VenvRunner = null; VenvRunner = null;
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override async Task WaitForShutdown() public override async Task WaitForShutdown()
{ {

148
StabilityMatrix.Core/Models/Packages/BasePackage.cs

@ -21,52 +21,68 @@ public abstract class BasePackage
public abstract string LicenseUrl { get; } public abstract string LicenseUrl { get; }
public virtual string Disclaimer => string.Empty; public virtual string Disclaimer => string.Empty;
public virtual bool OfferInOneClickInstaller => true; public virtual bool OfferInOneClickInstaller => true;
/// <summary> /// <summary>
/// Primary command to launch the package. 'Launch' buttons uses this. /// Primary command to launch the package. 'Launch' buttons uses this.
/// </summary> /// </summary>
public abstract string LaunchCommand { get; } public abstract string LaunchCommand { get; }
/// <summary> /// <summary>
/// Optional commands (e.g. 'config') that are on the launch button split drop-down. /// Optional commands (e.g. 'config') that are on the launch button split drop-down.
/// </summary> /// </summary>
public virtual IReadOnlyList<string> ExtraLaunchCommands { get; } = Array.Empty<string>(); public virtual IReadOnlyList<string> ExtraLaunchCommands { get; } = Array.Empty<string>();
public abstract Uri PreviewImageUri { get; } public abstract Uri PreviewImageUri { get; }
public virtual bool ShouldIgnoreReleases => false; public virtual bool ShouldIgnoreReleases => false;
public virtual bool UpdateAvailable { get; set; } public virtual bool UpdateAvailable { get; set; }
public abstract Task DownloadPackage(string installLocation, DownloadPackageVersionOptions versionOptions, public abstract Task DownloadPackage(
IProgress<ProgressReport>? progress1); string installLocation,
DownloadPackageVersionOptions versionOptions,
IProgress<ProgressReport>? progress1
);
public abstract Task InstallPackage(
string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
);
public abstract Task InstallPackage(string installLocation, TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null);
public abstract Task RunPackage(string installedPackagePath, string command, string arguments); public abstract Task RunPackage(string installedPackagePath, string command, string arguments);
public abstract Task<bool> CheckForUpdates(InstalledPackage package); public abstract Task<bool> CheckForUpdates(InstalledPackage package);
public abstract Task<InstalledPackageVersion> Update(InstalledPackage installedPackage, public abstract Task<InstalledPackageVersion> Update(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null, InstalledPackage installedPackage,
bool includePrerelease = false); TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null,
public virtual IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => new[] bool includePrerelease = false
{ );
SharedFolderMethod.Symlink,
SharedFolderMethod.Configuration, public virtual IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
SharedFolderMethod.None new[]
}; {
SharedFolderMethod.Symlink,
SharedFolderMethod.Configuration,
SharedFolderMethod.None
};
public abstract SharedFolderMethod RecommendedSharedFolderMethod { get; } public abstract SharedFolderMethod RecommendedSharedFolderMethod { get; }
public abstract Task SetupModelFolders(DirectoryPath installDirectory, public abstract Task SetupModelFolders(
SharedFolderMethod sharedFolderMethod); DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
);
public abstract Task UpdateModelFolders(DirectoryPath installDirectory, public abstract Task UpdateModelFolders(
SharedFolderMethod sharedFolderMethod); DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
);
public abstract Task RemoveModelFolderLinks(DirectoryPath installDirectory, public abstract Task RemoveModelFolderLinks(
SharedFolderMethod sharedFolderMethod); DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
);
public abstract IEnumerable<TorchVersion> AvailableTorchVersions { get; } public abstract IEnumerable<TorchVersion> AvailableTorchVersions { get; }
@ -77,7 +93,7 @@ public abstract class BasePackage
{ {
return AvailableTorchVersions.First(); return AvailableTorchVersions.First();
} }
if (HardwareHelper.HasNvidiaGpu() && AvailableTorchVersions.Contains(TorchVersion.Cuda)) if (HardwareHelper.HasNvidiaGpu() && AvailableTorchVersions.Contains(TorchVersion.Cuda))
{ {
return TorchVersion.Cuda; return TorchVersion.Cuda;
@ -88,15 +104,17 @@ public abstract class BasePackage
return TorchVersion.Rocm; return TorchVersion.Rocm;
} }
if (HardwareHelper.PreferDirectML() && if (
AvailableTorchVersions.Contains(TorchVersion.DirectMl)) HardwareHelper.PreferDirectML()
&& AvailableTorchVersions.Contains(TorchVersion.DirectMl)
)
{ {
return TorchVersion.DirectMl; return TorchVersion.DirectMl;
} }
return TorchVersion.Cpu; return TorchVersion.Cpu;
} }
/// <summary> /// <summary>
/// Shuts down the subprocess, canceling any pending streams. /// Shuts down the subprocess, canceling any pending streams.
/// </summary> /// </summary>
@ -110,16 +128,20 @@ public abstract class BasePackage
public abstract List<LaunchOptionDefinition> LaunchOptions { get; } public abstract List<LaunchOptionDefinition> LaunchOptions { get; }
public virtual string? ExtraLaunchArguments { get; set; } = null; public virtual string? ExtraLaunchArguments { get; set; } = null;
/// <summary> /// <summary>
/// The shared folders that this package supports. /// The shared folders that this package supports.
/// Mapping of <see cref="SharedFolderType"/> to the relative paths from the package root. /// Mapping of <see cref="SharedFolderType"/> to the relative paths from the package root.
/// </summary> /// </summary>
public virtual Dictionary<SharedFolderType, IReadOnlyList<string>>? SharedFolders { get; } public virtual Dictionary<SharedFolderType, IReadOnlyList<string>>? SharedFolders { get; }
public abstract Task<string> GetLatestVersion(); public abstract Task<string> GetLatestVersion();
public abstract Task<PackageVersionOptions> GetAllVersionOptions(); public abstract Task<PackageVersionOptions> GetAllVersionOptions();
public abstract Task<IEnumerable<GitCommit>?> GetAllCommits(string branch, int page = 1, int perPage = 10); public abstract Task<IEnumerable<GitCommit>?> GetAllCommits(
string branch,
int page = 1,
int perPage = 10
);
public abstract Task<IEnumerable<Branch>> GetAllBranches(); public abstract Task<IEnumerable<Branch>> GetAllBranches();
public abstract Task<IEnumerable<Release>> GetAllReleases(); public abstract Task<IEnumerable<Release>> GetAllReleases();
public event EventHandler<ProcessOutput>? ConsoleOutput; public event EventHandler<ProcessOutput>? ConsoleOutput;
@ -127,44 +149,56 @@ public abstract class BasePackage
public event EventHandler<string>? StartupComplete; public event EventHandler<string>? StartupComplete;
public void OnConsoleOutput(ProcessOutput output) => ConsoleOutput?.Invoke(this, output); public void OnConsoleOutput(ProcessOutput output) => ConsoleOutput?.Invoke(this, output);
public void OnExit(int exitCode) => Exited?.Invoke(this, exitCode); public void OnExit(int exitCode) => Exited?.Invoke(this, exitCode);
public void OnStartupComplete(string url) => StartupComplete?.Invoke(this, url); public void OnStartupComplete(string url) => StartupComplete?.Invoke(this, url);
public virtual PackageVersionType AvailableVersionTypes => ShouldIgnoreReleases public virtual PackageVersionType AvailableVersionTypes =>
? PackageVersionType.Commit ShouldIgnoreReleases
: PackageVersionType.GithubRelease | PackageVersionType.Commit; ? PackageVersionType.Commit
: PackageVersionType.GithubRelease | PackageVersionType.Commit;
protected async Task InstallCudaTorch(
PyVenvRunner venvRunner,
IProgress<ProgressReport>? progress = null
protected async Task InstallCudaTorch(PyVenvRunner venvRunner, )
IProgress<ProgressReport>? progress = null)
{ {
progress?.Report(new ProgressReport(-1f, "Installing PyTorch for CUDA", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing PyTorch for CUDA", isIndeterminate: true)
);
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, OnConsoleOutput) await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
await venvRunner.PipInstall("xformers", OnConsoleOutput).ConfigureAwait(false); await venvRunner.PipInstall("xformers", OnConsoleOutput).ConfigureAwait(false);
} }
protected async Task InstallDirectMlTorch(PyVenvRunner venvRunner, protected async Task InstallDirectMlTorch(
IProgress<ProgressReport>? progress = null) PyVenvRunner venvRunner,
IProgress<ProgressReport>? progress = null
)
{ {
progress?.Report(new ProgressReport(-1f, "Installing PyTorch for DirectML", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true)
);
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, OnConsoleOutput) await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
protected async Task InstallCpuTorch(PyVenvRunner venvRunner, IProgress<ProgressReport>? progress = null) protected async Task InstallCpuTorch(
PyVenvRunner venvRunner,
IProgress<ProgressReport>? progress = null
)
{ {
progress?.Report(new ProgressReport(-1f, "Installing PyTorch for CPU", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true)
);
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, OnConsoleOutput) await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
} }

227
StabilityMatrix.Core/Models/Packages/ComfyUI.cs

@ -21,7 +21,7 @@ public class ComfyUI : BaseGitPackage
public override string DisplayName { get; set; } = "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI";
public override string Author => "comfyanonymous"; public override string Author => "comfyanonymous";
public override string LicenseType => "GPL-3.0"; public override string LicenseType => "GPL-3.0";
public override string LicenseUrl => public override string LicenseUrl =>
"https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE"; "https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE";
public override string Blurb => "A powerful and modular stable diffusion GUI and backend"; public override string Blurb => "A powerful and modular stable diffusion GUI and backend";
public override string LaunchCommand => "main.py"; public override string LaunchCommand => "main.py";
@ -32,77 +32,82 @@ public class ComfyUI : BaseGitPackage
public override SharedFolderMethod RecommendedSharedFolderMethod => public override SharedFolderMethod RecommendedSharedFolderMethod =>
SharedFolderMethod.Configuration; SharedFolderMethod.Configuration;
public ComfyUI(IGithubApiCache githubApi, ISettingsManager settingsManager, IDownloadService downloadService, public ComfyUI(
IPrerequisiteHelper prerequisiteHelper) : IGithubApiCache githubApi,
base(githubApi, settingsManager, downloadService, prerequisiteHelper) ISettingsManager settingsManager,
{ IDownloadService downloadService,
} IPrerequisiteHelper prerequisiteHelper
)
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
// https://github.com/comfyanonymous/ComfyUI/blob/master/folder_paths.py#L11 // https://github.com/comfyanonymous/ComfyUI/blob/master/folder_paths.py#L11
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders => new() public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
{
[SharedFolderType.StableDiffusion] = new[] {"models/checkpoints"},
[SharedFolderType.Diffusers] = new[] {"models/diffusers"},
[SharedFolderType.Lora] = new[] {"models/loras"},
[SharedFolderType.CLIP] = new[] {"models/clip"},
[SharedFolderType.TextualInversion] = new[] {"models/embeddings"},
[SharedFolderType.VAE] = new[] {"models/vae"},
[SharedFolderType.ApproxVAE] = new[] {"models/vae_approx"},
[SharedFolderType.ControlNet] = new[] {"models/controlnet"},
[SharedFolderType.GLIGEN] = new[] {"models/gligen"},
[SharedFolderType.ESRGAN] = new[] {"models/upscale_models"},
[SharedFolderType.Hypernetwork] = new[] {"models/hypernetworks"},
};
public override List<LaunchOptionDefinition> LaunchOptions => new List<LaunchOptionDefinition>
{
new() new()
{ {
Name = "VRAM", [SharedFolderType.StableDiffusion] = new[] { "models/checkpoints" },
Type = LaunchOptionType.Bool, [SharedFolderType.Diffusers] = new[] { "models/diffusers" },
InitialValue = HardwareHelper.IterGpuInfo().Select(gpu => gpu.MemoryLevel).Max() switch [SharedFolderType.Lora] = new[] { "models/loras" },
[SharedFolderType.CLIP] = new[] { "models/clip" },
[SharedFolderType.TextualInversion] = new[] { "models/embeddings" },
[SharedFolderType.VAE] = new[] { "models/vae" },
[SharedFolderType.ApproxVAE] = new[] { "models/vae_approx" },
[SharedFolderType.ControlNet] = new[] { "models/controlnet" },
[SharedFolderType.GLIGEN] = new[] { "models/gligen" },
[SharedFolderType.ESRGAN] = new[] { "models/upscale_models" },
[SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
};
public override List<LaunchOptionDefinition> LaunchOptions =>
new List<LaunchOptionDefinition>
{
new()
{ {
Level.Low => "--lowvram", Name = "VRAM",
Level.Medium => "--normalvram", Type = LaunchOptionType.Bool,
_ => null InitialValue = HardwareHelper
.IterGpuInfo()
.Select(gpu => gpu.MemoryLevel)
.Max() switch
{
Level.Low => "--lowvram",
Level.Medium => "--normalvram",
_ => null
},
Options = { "--highvram", "--normalvram", "--lowvram", "--novram" }
}, },
Options = { "--highvram", "--normalvram", "--lowvram", "--novram" } new()
}, {
new() Name = "Use CPU only",
{ Type = LaunchOptionType.Bool,
Name = "Use CPU only", InitialValue = !HardwareHelper.HasNvidiaGpu(),
Type = LaunchOptionType.Bool, Options = { "--cpu" }
InitialValue = !HardwareHelper.HasNvidiaGpu(), },
Options = {"--cpu"} new()
}, {
new() Name = "Disable Xformers",
{ Type = LaunchOptionType.Bool,
Name = "Disable Xformers", InitialValue = !HardwareHelper.HasNvidiaGpu(),
Type = LaunchOptionType.Bool, Options = { "--disable-xformers" }
InitialValue = !HardwareHelper.HasNvidiaGpu(), },
Options = { "--disable-xformers" } new()
}, {
new() Name = "Auto-Launch",
{ Type = LaunchOptionType.Bool,
Name = "Auto-Launch", Options = { "--auto-launch" }
Type = LaunchOptionType.Bool, },
Options = { "--auto-launch" } LaunchOptionDefinition.Extras
}, };
LaunchOptionDefinition.Extras
};
public override Task<string> GetLatestVersion() => Task.FromResult("master"); public override Task<string> GetLatestVersion() => Task.FromResult("master");
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] public override IEnumerable<TorchVersion> AvailableTorchVersions =>
{ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.DirectMl, TorchVersion.Rocm };
TorchVersion.Cpu,
TorchVersion.Cuda, public override async Task InstallPackage(
TorchVersion.DirectMl, string installLocation,
TorchVersion.Rocm TorchVersion torchVersion,
}; IProgress<ProgressReport>? progress = null
)
public override async Task InstallPackage(string installLocation,
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null)
{ {
await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
@ -132,17 +137,21 @@ public class ComfyUI : BaseGitPackage
} }
// Install requirements file // Install requirements file
progress?.Report(new ProgressReport(-1, "Installing Package Requirements", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1, "Installing Package Requirements", isIndeterminate: true)
);
Logger.Info("Installing requirements.txt"); Logger.Info("Installing requirements.txt");
await venvRunner.PipInstall($"-r requirements.txt", OnConsoleOutput).ConfigureAwait(false); await venvRunner.PipInstall($"-r requirements.txt", OnConsoleOutput).ConfigureAwait(false);
progress?.Report(new ProgressReport(1, "Installing Package Requirements", progress?.Report(
isIndeterminate: false)); new ProgressReport(1, "Installing Package Requirements", isIndeterminate: false)
);
} }
private async Task AutoDetectAndInstallTorch(PyVenvRunner venvRunner, private async Task AutoDetectAndInstallTorch(
IProgress<ProgressReport>? progress = null) PyVenvRunner venvRunner,
IProgress<ProgressReport>? progress = null
)
{ {
var gpus = HardwareHelper.IterGpuInfo().ToList(); var gpus = HardwareHelper.IterGpuInfo().ToList();
if (gpus.Any(g => g.IsNvidia)) if (gpus.Any(g => g.IsNvidia))
@ -163,14 +172,18 @@ public class ComfyUI : BaseGitPackage
} }
} }
public override async Task RunPackage(string installedPackagePath, string command, string arguments) public override async Task RunPackage(
string installedPackagePath,
string command,
string arguments
)
{ {
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
void HandleConsoleOutput(ProcessOutput s) void HandleConsoleOutput(ProcessOutput s)
{ {
OnConsoleOutput(s); OnConsoleOutput(s);
if (s.Text.Contains("To see the GUI go to", StringComparison.OrdinalIgnoreCase)) if (s.Text.Contains("To see the GUI go to", StringComparison.OrdinalIgnoreCase))
{ {
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
@ -191,14 +204,13 @@ public class ComfyUI : BaseGitPackage
var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
VenvRunner?.RunDetached( VenvRunner?.RunDetached(args.TrimEnd(), HandleConsoleOutput, HandleExit);
args.TrimEnd(),
HandleConsoleOutput,
HandleExit);
} }
public override Task SetupModelFolders(DirectoryPath installDirectory, public override Task SetupModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
switch (sharedFolderMethod) switch (sharedFolderMethod)
{ {
@ -223,19 +235,22 @@ public class ComfyUI : BaseGitPackage
File.WriteAllText(extraPathsYamlPath, string.Empty); File.WriteAllText(extraPathsYamlPath, string.Empty);
} }
var yaml = File.ReadAllText(extraPathsYamlPath); var yaml = File.ReadAllText(extraPathsYamlPath);
var comfyModelPaths = deserializer.Deserialize<ComfyModelPathsYaml>(yaml) ?? var comfyModelPaths =
// ReSharper disable once NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract deserializer.Deserialize<ComfyModelPathsYaml>(yaml)
// cuz it can actually be null lol ??
new ComfyModelPathsYaml(); // ReSharper disable once NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract
// cuz it can actually be null lol
new ComfyModelPathsYaml();
comfyModelPaths.StabilityMatrix ??= new ComfyModelPathsYaml.SmData(); comfyModelPaths.StabilityMatrix ??= new ComfyModelPathsYaml.SmData();
comfyModelPaths.StabilityMatrix.Checkpoints = Path.Combine(modelsDir, "StableDiffusion"); comfyModelPaths.StabilityMatrix.Checkpoints = Path.Combine(modelsDir, "StableDiffusion");
comfyModelPaths.StabilityMatrix.Vae = Path.Combine(modelsDir, "VAE"); comfyModelPaths.StabilityMatrix.Vae = Path.Combine(modelsDir, "VAE");
comfyModelPaths.StabilityMatrix.Loras = $"{Path.Combine(modelsDir, "Lora")}\n" + comfyModelPaths.StabilityMatrix.Loras =
$"{Path.Combine(modelsDir, "LyCORIS")}"; $"{Path.Combine(modelsDir, "Lora")}\n" + $"{Path.Combine(modelsDir, "LyCORIS")}";
comfyModelPaths.StabilityMatrix.UpscaleModels = $"{Path.Combine(modelsDir, "ESRGAN")}\n" + comfyModelPaths.StabilityMatrix.UpscaleModels =
$"{Path.Combine(modelsDir, "RealESRGAN")}\n" + $"{Path.Combine(modelsDir, "ESRGAN")}\n"
$"{Path.Combine(modelsDir, "SwinIR")}"; + $"{Path.Combine(modelsDir, "RealESRGAN")}\n"
+ $"{Path.Combine(modelsDir, "SwinIR")}";
comfyModelPaths.StabilityMatrix.Embeddings = Path.Combine(modelsDir, "TextualInversion"); comfyModelPaths.StabilityMatrix.Embeddings = Path.Combine(modelsDir, "TextualInversion");
comfyModelPaths.StabilityMatrix.Hypernetworks = Path.Combine(modelsDir, "Hypernetwork"); comfyModelPaths.StabilityMatrix.Hypernetworks = Path.Combine(modelsDir, "Hypernetwork");
comfyModelPaths.StabilityMatrix.Controlnet = Path.Combine(modelsDir, "ControlNet"); comfyModelPaths.StabilityMatrix.Controlnet = Path.Combine(modelsDir, "ControlNet");
@ -243,7 +258,7 @@ public class ComfyUI : BaseGitPackage
comfyModelPaths.StabilityMatrix.Diffusers = Path.Combine(modelsDir, "Diffusers"); comfyModelPaths.StabilityMatrix.Diffusers = Path.Combine(modelsDir, "Diffusers");
comfyModelPaths.StabilityMatrix.Gligen = Path.Combine(modelsDir, "GLIGEN"); comfyModelPaths.StabilityMatrix.Gligen = Path.Combine(modelsDir, "GLIGEN");
comfyModelPaths.StabilityMatrix.VaeApprox = Path.Combine(modelsDir, "ApproxVAE"); comfyModelPaths.StabilityMatrix.VaeApprox = Path.Combine(modelsDir, "ApproxVAE");
var serializer = new SerializerBuilder() var serializer = new SerializerBuilder()
.WithNamingConvention(UnderscoredNamingConvention.Instance) .WithNamingConvention(UnderscoredNamingConvention.Instance)
.Build(); .Build();
@ -253,33 +268,39 @@ public class ComfyUI : BaseGitPackage
return Task.CompletedTask; return Task.CompletedTask;
} }
public override Task UpdateModelFolders(DirectoryPath installDirectory, public override Task UpdateModelFolders(
SharedFolderMethod sharedFolderMethod) => DirectoryPath installDirectory,
SetupModelFolders(installDirectory, sharedFolderMethod); SharedFolderMethod sharedFolderMethod
) => SetupModelFolders(installDirectory, sharedFolderMethod);
public override Task RemoveModelFolderLinks(DirectoryPath installDirectory, public override Task RemoveModelFolderLinks(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
return sharedFolderMethod switch return sharedFolderMethod switch
{ {
SharedFolderMethod.Configuration => Task.CompletedTask, SharedFolderMethod.Configuration => Task.CompletedTask,
SharedFolderMethod.None => Task.CompletedTask, SharedFolderMethod.None => Task.CompletedTask,
SharedFolderMethod.Symlink => base.RemoveModelFolderLinks(installDirectory, SharedFolderMethod.Symlink
sharedFolderMethod), => base.RemoveModelFolderLinks(installDirectory, sharedFolderMethod),
_ => Task.CompletedTask _ => Task.CompletedTask
}; };
} }
private async Task InstallRocmTorch(PyVenvRunner venvRunner, private async Task InstallRocmTorch(
IProgress<ProgressReport>? progress = null) PyVenvRunner venvRunner,
IProgress<ProgressReport>? progress = null
)
{ {
progress?.Report(new ProgressReport(-1f, "Installing PyTorch for ROCm", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing PyTorch for ROCm", isIndeterminate: true)
);
await venvRunner.PipInstall("--upgrade pip wheel", OnConsoleOutput) await venvRunner.PipInstall("--upgrade pip wheel", OnConsoleOutput).ConfigureAwait(false);
.ConfigureAwait(false);
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, OnConsoleOutput) await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
} }

45
StabilityMatrix.Core/Models/Packages/DankDiffusion.cs

@ -7,28 +7,26 @@ namespace StabilityMatrix.Core.Models.Packages;
public class DankDiffusion : BaseGitPackage public class DankDiffusion : BaseGitPackage
{ {
public DankDiffusion(IGithubApiCache githubApi, ISettingsManager settingsManager, IDownloadService downloadService, public DankDiffusion(
IPrerequisiteHelper prerequisiteHelper) : IGithubApiCache githubApi,
base(githubApi, settingsManager, downloadService, prerequisiteHelper) ISettingsManager settingsManager,
{ IDownloadService downloadService,
} IPrerequisiteHelper prerequisiteHelper
)
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
public override string Name => "dank-diffusion"; public override string Name => "dank-diffusion";
public override string DisplayName { get; set; } = "Dank Diffusion"; public override string DisplayName { get; set; } = "Dank Diffusion";
public override string Author => "mohnjiles"; public override string Author => "mohnjiles";
public override string LicenseType => "AGPL-3.0"; public override string LicenseType => "AGPL-3.0";
public override string LicenseUrl => public override string LicenseUrl =>
"https://github.com/LykosAI/StabilityMatrix/blob/main/LICENSE"; "https://github.com/LykosAI/StabilityMatrix/blob/main/LICENSE";
public override string Blurb => "A dank interface for diffusion"; public override string Blurb => "A dank interface for diffusion";
public override string LaunchCommand => "test"; public override string LaunchCommand => "test";
public override SharedFolderMethod RecommendedSharedFolderMethod => public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
SharedFolderMethod.Symlink;
public override IReadOnlyList<string> ExtraLaunchCommands => new[] { "test-config", };
public override IReadOnlyList<string> ExtraLaunchCommands => new[]
{
"test-config",
};
public override Uri PreviewImageUri { get; } public override Uri PreviewImageUri { get; }
public override Task RunPackage(string installedPackagePath, string command, string arguments) public override Task RunPackage(string installedPackagePath, string command, string arguments)
@ -36,20 +34,26 @@ public class DankDiffusion : BaseGitPackage
throw new NotImplementedException(); throw new NotImplementedException();
} }
public override Task SetupModelFolders(DirectoryPath installDirectory, public override Task SetupModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
public override Task UpdateModelFolders(DirectoryPath installDirectory, public override Task UpdateModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
public override Task RemoveModelFolderLinks(DirectoryPath installDirectory, public override Task RemoveModelFolderLinks(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
@ -57,6 +61,7 @@ public class DankDiffusion : BaseGitPackage
public override IEnumerable<TorchVersion> AvailableTorchVersions { get; } public override IEnumerable<TorchVersion> AvailableTorchVersions { get; }
public override List<LaunchOptionDefinition> LaunchOptions { get; } public override List<LaunchOptionDefinition> LaunchOptions { get; }
public override Task<string> GetLatestVersion() public override Task<string> GetLatestVersion()
{ {
throw new NotImplementedException(); throw new NotImplementedException();

153
StabilityMatrix.Core/Models/Packages/Fooocus.cs

@ -10,11 +10,13 @@ namespace StabilityMatrix.Core.Models.Packages;
public class Fooocus : BaseGitPackage public class Fooocus : BaseGitPackage
{ {
public Fooocus(IGithubApiCache githubApi, ISettingsManager settingsManager, public Fooocus(
IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper) : base(githubApi, IGithubApiCache githubApi,
settingsManager, downloadService, prerequisiteHelper) ISettingsManager settingsManager,
{ IDownloadService downloadService,
} IPrerequisiteHelper prerequisiteHelper
)
: base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
public override string Name => "Fooocus"; public override string Name => "Fooocus";
public override string DisplayName { get; set; } = "Fooocus"; public override string DisplayName { get; set; } = "Fooocus";
@ -28,76 +30,76 @@ public class Fooocus : BaseGitPackage
public override string LaunchCommand => "launch.py"; public override string LaunchCommand => "launch.py";
public override Uri PreviewImageUri => public override Uri PreviewImageUri =>
new("https://user-images.githubusercontent.com/19834515/261830306-f79c5981-cf80-4ee3-b06b-3fef3f8bfbc7.png"); new(
"https://user-images.githubusercontent.com/19834515/261830306-f79c5981-cf80-4ee3-b06b-3fef3f8bfbc7.png"
);
public override List<LaunchOptionDefinition> LaunchOptions => new() public override List<LaunchOptionDefinition> LaunchOptions =>
{ new()
new LaunchOptionDefinition
{ {
Name = "Port", new LaunchOptionDefinition
Type = LaunchOptionType.String, {
Description = "Sets the listen port", Name = "Port",
Options = {"--port"} Type = LaunchOptionType.String,
}, Description = "Sets the listen port",
new LaunchOptionDefinition Options = { "--port" }
{ },
Name = "Share", new LaunchOptionDefinition
Type = LaunchOptionType.Bool, {
Description = "Set whether to share on Gradio", Name = "Share",
Options = {"--share"} Type = LaunchOptionType.Bool,
}, Description = "Set whether to share on Gradio",
new LaunchOptionDefinition Options = { "--share" }
},
new LaunchOptionDefinition
{
Name = "Listen",
Type = LaunchOptionType.String,
Description = "Set the listen interface",
Options = { "--listen" }
},
LaunchOptionDefinition.Extras
};
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
new()
{ {
Name = "Listen", [SharedFolderType.StableDiffusion] = new[] { "models/checkpoints" },
Type = LaunchOptionType.String, [SharedFolderType.Diffusers] = new[] { "models/diffusers" },
Description = "Set the listen interface", [SharedFolderType.Lora] = new[] { "models/loras" },
Options = {"--listen"} [SharedFolderType.CLIP] = new[] { "models/clip" },
}, [SharedFolderType.TextualInversion] = new[] { "models/embeddings" },
LaunchOptionDefinition.Extras [SharedFolderType.VAE] = new[] { "models/vae" },
}; [SharedFolderType.ApproxVAE] = new[] { "models/vae_approx" },
[SharedFolderType.ControlNet] = new[] { "models/controlnet" },
public override SharedFolderMethod RecommendedSharedFolderMethod => [SharedFolderType.GLIGEN] = new[] { "models/gligen" },
SharedFolderMethod.Symlink; [SharedFolderType.ESRGAN] = new[] { "models/upscale_models" },
[SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" }
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => new[] };
{
SharedFolderMethod.Symlink, public override IEnumerable<TorchVersion> AvailableTorchVersions =>
SharedFolderMethod.None new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm };
};
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders => new()
{
[SharedFolderType.StableDiffusion] = new[] {"models/checkpoints"},
[SharedFolderType.Diffusers] = new[] {"models/diffusers"},
[SharedFolderType.Lora] = new[] {"models/loras"},
[SharedFolderType.CLIP] = new[] {"models/clip"},
[SharedFolderType.TextualInversion] = new[] {"models/embeddings"},
[SharedFolderType.VAE] = new[] {"models/vae"},
[SharedFolderType.ApproxVAE] = new[] {"models/vae_approx"},
[SharedFolderType.ControlNet] = new[] {"models/controlnet"},
[SharedFolderType.GLIGEN] = new[] {"models/gligen"},
[SharedFolderType.ESRGAN] = new[] {"models/upscale_models"},
[SharedFolderType.Hypernetwork] = new[] {"models/hypernetworks"}
};
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[]
{
TorchVersion.Cpu,
TorchVersion.Cuda,
TorchVersion.Rocm
};
public override async Task<string> GetLatestVersion() public override async Task<string> GetLatestVersion()
{ {
var release = await GetLatestRelease().ConfigureAwait(false); var release = await GetLatestRelease().ConfigureAwait(false);
return release.TagName!; return release.TagName!;
} }
public override async Task InstallPackage(string installLocation, public override async Task InstallPackage(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
var venvRunner = await SetupVenv(installLocation, forceRecreate: true).ConfigureAwait(false); var venvRunner = await SetupVenv(installLocation, forceRecreate: true)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
@ -120,22 +122,30 @@ public class Fooocus : BaseGitPackage
await venvRunner await venvRunner
.PipInstall( .PipInstall(
$"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersionStr}", $"torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/{torchVersionStr}",
OnConsoleOutput).ConfigureAwait(false); OnConsoleOutput
)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(-1f, "Installing requirements...", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)
await venvRunner.PipInstall("-r requirements_versions.txt", OnConsoleOutput) );
await venvRunner
.PipInstall("-r requirements_versions.txt", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
public override async Task RunPackage(string installedPackagePath, string command, string arguments) public override async Task RunPackage(
string installedPackagePath,
string command,
string arguments
)
{ {
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
void HandleConsoleOutput(ProcessOutput s) void HandleConsoleOutput(ProcessOutput s)
{ {
OnConsoleOutput(s); OnConsoleOutput(s);
if (s.Text.Contains("Use the app with", StringComparison.OrdinalIgnoreCase)) if (s.Text.Contains("Use the app with", StringComparison.OrdinalIgnoreCase))
{ {
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
@ -156,9 +166,6 @@ public class Fooocus : BaseGitPackage
var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
VenvRunner?.RunDetached( VenvRunner?.RunDetached(args.TrimEnd(), HandleConsoleOutput, HandleExit);
args.TrimEnd(),
HandleConsoleOutput,
HandleExit);
} }
} }

270
StabilityMatrix.Core/Models/Packages/InvokeAI.cs

@ -21,114 +21,112 @@ public class InvokeAI : BaseGitPackage
public override string Author => "invoke-ai"; public override string Author => "invoke-ai";
public override string LicenseType => "Apache-2.0"; public override string LicenseType => "Apache-2.0";
public override string LicenseUrl => public override string LicenseUrl => "https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE";
"https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE";
public override string Blurb => "Professional Creative Tools for Stable Diffusion"; public override string Blurb => "Professional Creative Tools for Stable Diffusion";
public override string LaunchCommand => "invokeai-web"; public override string LaunchCommand => "invokeai-web";
public override IReadOnlyList<string> ExtraLaunchCommands => new[] public override IReadOnlyList<string> ExtraLaunchCommands =>
{ new[]
"invokeai-configure", {
"invokeai-merge", "invokeai-configure",
"invokeai-metadata", "invokeai-merge",
"invokeai-model-install", "invokeai-metadata",
"invokeai-node-cli", "invokeai-model-install",
"invokeai-ti", "invokeai-node-cli",
"invokeai-update", "invokeai-ti",
}; "invokeai-update",
};
public override Uri PreviewImageUri => new(
"https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/docs/assets/canvas_preview.png"); public override Uri PreviewImageUri =>
new(
"https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/docs/assets/canvas_preview.png"
);
public override bool ShouldIgnoreReleases => true; public override bool ShouldIgnoreReleases => true;
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => new[] public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
{ new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
SharedFolderMethod.Symlink, public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
SharedFolderMethod.None
};
public override SharedFolderMethod RecommendedSharedFolderMethod =>
SharedFolderMethod.Symlink;
public InvokeAI( public InvokeAI(
IGithubApiCache githubApi, IGithubApiCache githubApi,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IDownloadService downloadService, IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper) : IPrerequisiteHelper prerequisiteHelper
base(githubApi, settingsManager, downloadService, prerequisiteHelper) )
{ : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
}
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders => new()
{
[SharedFolderType.StableDiffusion] = new[] { RelativeRootPath + "/autoimport/main" },
[SharedFolderType.Lora] = new[] { RelativeRootPath + "/autoimport/lora" },
[SharedFolderType.TextualInversion] = new[] { RelativeRootPath + "/autoimport/embedding" },
[SharedFolderType.ControlNet] = new[] { RelativeRootPath + "/autoimport/controlnet" },
};
// https://github.com/invoke-ai/InvokeAI/blob/main/docs/features/CONFIGURATION.md public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
public override List<LaunchOptionDefinition> LaunchOptions => new List<LaunchOptionDefinition>
{
new() new()
{ {
Name = "Host", [SharedFolderType.StableDiffusion] = new[] { RelativeRootPath + "/autoimport/main" },
Type = LaunchOptionType.String, [SharedFolderType.Lora] = new[] { RelativeRootPath + "/autoimport/lora" },
DefaultValue = "localhost", [SharedFolderType.TextualInversion] = new[]
Options = new List<string> {"--host"}
},
new()
{
Name = "Port",
Type = LaunchOptionType.String,
DefaultValue = "9090",
Options = new List<string> {"--port"}
},
new()
{
Name = "Allow Origins",
Description = "List of host names or IP addresses that are allowed to connect to the " +
"InvokeAI API in the format ['host1','host2',...]",
Type = LaunchOptionType.String,
DefaultValue = "[]",
Options = new List<string> {"--allow-origins"}
},
new()
{
Name = "Always use CPU",
Type = LaunchOptionType.Bool,
Options = new List<string> {"--always_use_cpu"}
},
new()
{
Name = "Precision",
Type = LaunchOptionType.Bool,
Options = new List<string>
{ {
"--precision auto", RelativeRootPath + "/autoimport/embedding"
"--precision float16", },
"--precision float32", [SharedFolderType.ControlNet] = new[] { RelativeRootPath + "/autoimport/controlnet" },
} };
},
new() // https://github.com/invoke-ai/InvokeAI/blob/main/docs/features/CONFIGURATION.md
public override List<LaunchOptionDefinition> LaunchOptions =>
new List<LaunchOptionDefinition>
{ {
Name = "Aggressively free up GPU memory after each operation", new()
Type = LaunchOptionType.Bool, {
Options = new List<string> {"--free_gpu_mem"} Name = "Host",
}, Type = LaunchOptionType.String,
LaunchOptionDefinition.Extras DefaultValue = "localhost",
}; Options = new List<string> { "--host" }
},
new()
{
Name = "Port",
Type = LaunchOptionType.String,
DefaultValue = "9090",
Options = new List<string> { "--port" }
},
new()
{
Name = "Allow Origins",
Description =
"List of host names or IP addresses that are allowed to connect to the "
+ "InvokeAI API in the format ['host1','host2',...]",
Type = LaunchOptionType.String,
DefaultValue = "[]",
Options = new List<string> { "--allow-origins" }
},
new()
{
Name = "Always use CPU",
Type = LaunchOptionType.Bool,
Options = new List<string> { "--always_use_cpu" }
},
new()
{
Name = "Precision",
Type = LaunchOptionType.Bool,
Options = new List<string>
{
"--precision auto",
"--precision float16",
"--precision float32",
}
},
new()
{
Name = "Aggressively free up GPU memory after each operation",
Type = LaunchOptionType.Bool,
Options = new List<string> { "--free_gpu_mem" }
},
LaunchOptionDefinition.Extras
};
public override Task<string> GetLatestVersion() => Task.FromResult("main"); public override Task<string> GetLatestVersion() => Task.FromResult("main");
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] public override IEnumerable<TorchVersion> AvailableTorchVersions =>
{ new[] { TorchVersion.Cpu, TorchVersion.Cuda, TorchVersion.Rocm, TorchVersion.Mps };
TorchVersion.Cpu,
TorchVersion.Cuda,
TorchVersion.Rocm,
TorchVersion.Mps
};
public override TorchVersion GetRecommendedTorchVersion() public override TorchVersion GetRecommendedTorchVersion()
{ {
@ -140,14 +138,20 @@ public class InvokeAI : BaseGitPackage
return base.GetRecommendedTorchVersion(); return base.GetRecommendedTorchVersion();
} }
public override Task DownloadPackage(string installLocation, public override Task DownloadPackage(
DownloadPackageVersionOptions downloadOptions, IProgress<ProgressReport>? progress = null) string installLocation,
DownloadPackageVersionOptions downloadOptions,
IProgress<ProgressReport>? progress = null
)
{ {
return Task.CompletedTask; return Task.CompletedTask;
} }
public override async Task InstallPackage(string installLocation, TorchVersion torchVersion, public override async Task InstallPackage(
IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
// Setup venv // Setup venv
progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
@ -169,13 +173,13 @@ public class InvokeAI : BaseGitPackage
pipCommandArgs = pipCommandArgs =
"InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117"; "InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117";
break; break;
case TorchVersion.Rocm: case TorchVersion.Rocm:
Logger.Info("Starting InvokeAI install (ROCm)..."); Logger.Info("Starting InvokeAI install (ROCm)...");
pipCommandArgs = pipCommandArgs =
"InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; "InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2";
break; break;
case TorchVersion.Mps: case TorchVersion.Mps:
Logger.Info("Starting InvokeAI install (MPS)..."); Logger.Info("Starting InvokeAI install (MPS)...");
pipCommandArgs = "InvokeAI --use-pep517"; pipCommandArgs = "InvokeAI --use-pep517";
@ -190,15 +194,23 @@ public class InvokeAI : BaseGitPackage
progress?.Report(new ProgressReport(-1f, "Configuring InvokeAI", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Configuring InvokeAI", isIndeterminate: true));
await RunInvokeCommand(installLocation, "invokeai-configure", "--yes --skip-sd-weights", await RunInvokeCommand(
false).ConfigureAwait(false); installLocation,
"invokeai-configure",
"--yes --skip-sd-weights",
false
)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false)); progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false));
} }
public override async Task<InstalledPackageVersion> Update(InstalledPackage installedPackage, public override async Task<InstalledPackageVersion> Update(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null, InstalledPackage installedPackage,
bool includePrerelease = false) TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null,
bool includePrerelease = false
)
{ {
progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Setting up venv", isIndeterminate: true));
@ -206,8 +218,10 @@ public class InvokeAI : BaseGitPackage
{ {
throw new NullReferenceException("Installed package is missing Path and/or Version"); throw new NullReferenceException("Installed package is missing Path and/or Version");
} }
await using var venvRunner = new PyVenvRunner(Path.Combine(installedPackage.FullPath, "venv")); await using var venvRunner = new PyVenvRunner(
Path.Combine(installedPackage.FullPath, "venv")
);
venvRunner.WorkingDirectory = installedPackage.FullPath; venvRunner.WorkingDirectory = installedPackage.FullPath;
venvRunner.EnvironmentVariables = GetEnvVars(installedPackage.FullPath); venvRunner.EnvironmentVariables = GetEnvVars(installedPackage.FullPath);
@ -251,10 +265,7 @@ public class InvokeAI : BaseGitPackage
progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false)); progress?.Report(new ProgressReport(1f, "Done!", isIndeterminate: false));
return isReleaseMode return isReleaseMode
? new InstalledPackageVersion ? new InstalledPackageVersion { InstalledReleaseVersion = latestVersion }
{
InstalledReleaseVersion = latestVersion
}
: new InstalledPackageVersion : new InstalledPackageVersion
{ {
InstalledBranch = installedPackage.Version.InstalledBranch, InstalledBranch = installedPackage.Version.InstalledBranch,
@ -262,15 +273,20 @@ public class InvokeAI : BaseGitPackage
}; };
} }
public override Task public override Task RunPackage(
RunPackage(string installedPackagePath, string command, string arguments) => string installedPackagePath,
RunInvokeCommand(installedPackagePath, command, arguments, true); string command,
string arguments
) => RunInvokeCommand(installedPackagePath, command, arguments, true);
private async Task<string> GetUpdateVersion(InstalledPackage installedPackage, bool includePrerelease = false) private async Task<string> GetUpdateVersion(
InstalledPackage installedPackage,
bool includePrerelease = false
)
{ {
if (installedPackage.Version == null) if (installedPackage.Version == null)
throw new NullReferenceException("Installed package version is null"); throw new NullReferenceException("Installed package version is null");
if (installedPackage.Version.IsReleaseMode) if (installedPackage.Version.IsReleaseMode)
{ {
var releases = await GetAllReleases().ConfigureAwait(false); var releases = await GetAllReleases().ConfigureAwait(false);
@ -284,8 +300,12 @@ public class InvokeAI : BaseGitPackage
return latestCommit.Sha; return latestCommit.Sha;
} }
private async Task RunInvokeCommand(string installedPackagePath, string command, private async Task RunInvokeCommand(
string arguments, bool runDetached) string installedPackagePath,
string command,
string arguments,
bool runDetached
)
{ {
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
@ -304,12 +324,12 @@ public class InvokeAI : BaseGitPackage
// Split at ':' to get package and function // Split at ':' to get package and function
var split = entryPoint?.Split(':'); var split = entryPoint?.Split(':');
if (split is not {Length: > 1}) if (split is not { Length: > 1 })
{ {
throw new Exception($"Could not find entry point for InvokeAI: {entryPoint.ToRepr()}"); throw new Exception($"Could not find entry point for InvokeAI: {entryPoint.ToRepr()}");
} }
// Compile a startup command according to // Compile a startup command according to
// https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts // https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts
// For invokeai, also patch the shutil.get_terminal_size function to return a fixed value // For invokeai, also patch the shutil.get_terminal_size function to return a fixed value
// above the minimum in invokeai.frontend.install.widgets // above the minimum in invokeai.frontend.install.widgets
@ -326,28 +346,30 @@ public class InvokeAI : BaseGitPackage
{ {
OnConsoleOutput(s); OnConsoleOutput(s);
if (!s.Text.Contains("running on", StringComparison.OrdinalIgnoreCase)) if (!s.Text.Contains("running on", StringComparison.OrdinalIgnoreCase))
return; return;
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
var match = regex.Match(s.Text); var match = regex.Match(s.Text);
if (!match.Success) if (!match.Success)
return; return;
WebUrl = match.Value; WebUrl = match.Value;
OnStartupComplete(WebUrl); OnStartupComplete(WebUrl);
} }
VenvRunner.RunDetached($"-c \"{code}\" {arguments}".TrimEnd(), HandleConsoleOutput, OnExit); VenvRunner.RunDetached(
$"-c \"{code}\" {arguments}".TrimEnd(),
HandleConsoleOutput,
OnExit
);
} }
else else
{ {
var result = await VenvRunner.Run($"-c \"{code}\" {arguments}".TrimEnd()) var result = await VenvRunner
.Run($"-c \"{code}\" {arguments}".TrimEnd())
.ConfigureAwait(false); .ConfigureAwait(false);
OnConsoleOutput(new ProcessOutput OnConsoleOutput(new ProcessOutput { Text = result.StandardOutput });
{
Text = result.StandardOutput
});
} }
} }

84
StabilityMatrix.Core/Models/Packages/UnknownPackage.cs

@ -14,30 +14,32 @@ public class UnknownPackage : BasePackage
public override string GithubUrl => ""; public override string GithubUrl => "";
public override string LicenseType => "AGPL-3.0"; public override string LicenseType => "AGPL-3.0";
public override string LicenseUrl => public override string LicenseUrl =>
"https://github.com/LykosAI/StabilityMatrix/blob/main/LICENSE"; "https://github.com/LykosAI/StabilityMatrix/blob/main/LICENSE";
public override string Blurb => "A dank interface for diffusion"; public override string Blurb => "A dank interface for diffusion";
public override string LaunchCommand => "test"; public override string LaunchCommand => "test";
public override Uri PreviewImageUri => new(""); public override Uri PreviewImageUri => new("");
public override IReadOnlyList<string> ExtraLaunchCommands => new[] public override IReadOnlyList<string> ExtraLaunchCommands => new[] { "test-config", };
{
"test-config", public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
};
public override Task DownloadPackage(
public override SharedFolderMethod RecommendedSharedFolderMethod => string installLocation,
SharedFolderMethod.Symlink; DownloadPackageVersionOptions versionOptions,
IProgress<ProgressReport>? progress1
public override Task DownloadPackage(string installLocation, DownloadPackageVersionOptions versionOptions, )
IProgress<ProgressReport>? progress1)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task InstallPackage(string installLocation, TorchVersion torchVersion, public override Task InstallPackage(
IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
@ -48,33 +50,34 @@ public class UnknownPackage : BasePackage
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task SetupModelFolders(DirectoryPath installDirectory, public override Task SetupModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task UpdateModelFolders(DirectoryPath installDirectory, public override Task UpdateModelFolders(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task RemoveModelFolderLinks(DirectoryPath installDirectory, public override Task RemoveModelFolderLinks(
SharedFolderMethod sharedFolderMethod) DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] public override IEnumerable<TorchVersion> AvailableTorchVersions =>
{ new[] { TorchVersion.Cuda, TorchVersion.Cpu, TorchVersion.Rocm, TorchVersion.DirectMl };
TorchVersion.Cuda,
TorchVersion.Cpu,
TorchVersion.Rocm,
TorchVersion.DirectMl
};
/// <inheritdoc /> /// <inheritdoc />
public override void Shutdown() public override void Shutdown()
@ -95,28 +98,39 @@ public class UnknownPackage : BasePackage
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<InstalledPackageVersion> Update(InstalledPackage installedPackage, public override Task<InstalledPackageVersion> Update(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null, InstalledPackage installedPackage,
bool includePrerelease = false) TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null,
bool includePrerelease = false
)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<IEnumerable<Release>> GetReleaseTags() => Task.FromResult(Enumerable.Empty<Release>()); public override Task<IEnumerable<Release>> GetReleaseTags() =>
Task.FromResult(Enumerable.Empty<Release>());
public override List<LaunchOptionDefinition> LaunchOptions => new(); public override List<LaunchOptionDefinition> LaunchOptions => new();
public override Task<string> GetLatestVersion() => Task.FromResult(string.Empty); public override Task<string> GetLatestVersion() => Task.FromResult(string.Empty);
public override Task<PackageVersionOptions> GetAllVersionOptions() => public override Task<PackageVersionOptions> GetAllVersionOptions() =>
Task.FromResult(new PackageVersionOptions()); Task.FromResult(new PackageVersionOptions());
/// <inheritdoc /> /// <inheritdoc />
public override Task<IEnumerable<GitCommit>?> GetAllCommits(string branch, int page = 1, int perPage = 10) => Task.FromResult<IEnumerable<GitCommit>?>(null); public override Task<IEnumerable<GitCommit>?> GetAllCommits(
string branch,
int page = 1,
int perPage = 10
) => Task.FromResult<IEnumerable<GitCommit>?>(null);
/// <inheritdoc /> /// <inheritdoc />
public override Task<IEnumerable<Branch>> GetAllBranches() => Task.FromResult(Enumerable.Empty<Branch>()); public override Task<IEnumerable<Branch>> GetAllBranches() =>
Task.FromResult(Enumerable.Empty<Branch>());
/// <inheritdoc /> /// <inheritdoc />
public override Task<IEnumerable<Release>> GetAllReleases() => Task.FromResult(Enumerable.Empty<Release>()); public override Task<IEnumerable<Release>> GetAllReleases() =>
Task.FromResult(Enumerable.Empty<Release>());
} }

387
StabilityMatrix.Core/Models/Packages/VladAutomatic.cs

@ -17,12 +17,12 @@ namespace StabilityMatrix.Core.Models.Packages;
public class VladAutomatic : BaseGitPackage public class VladAutomatic : BaseGitPackage
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
public override string Name => "automatic"; public override string Name => "automatic";
public override string DisplayName { get; set; } = "SD.Next Web UI"; public override string DisplayName { get; set; } = "SD.Next Web UI";
public override string Author => "vladmandic"; public override string Author => "vladmandic";
public override string LicenseType => "AGPL-3.0"; public override string LicenseType => "AGPL-3.0";
public override string LicenseUrl => public override string LicenseUrl =>
"https://github.com/vladmandic/automatic/blob/master/LICENSE.txt"; "https://github.com/vladmandic/automatic/blob/master/LICENSE.txt";
public override string Blurb => "Stable Diffusion implementation with advanced features"; public override string Blurb => "Stable Diffusion implementation with advanced features";
public override string LaunchCommand => "launch.py"; public override string LaunchCommand => "launch.py";
@ -30,129 +30,133 @@ public class VladAutomatic : BaseGitPackage
public override Uri PreviewImageUri => public override Uri PreviewImageUri =>
new("https://github.com/vladmandic/automatic/raw/master/html/black-orange.jpg"); new("https://github.com/vladmandic/automatic/raw/master/html/black-orange.jpg");
public override bool ShouldIgnoreReleases => true; public override bool ShouldIgnoreReleases => true;
public override SharedFolderMethod RecommendedSharedFolderMethod =>
SharedFolderMethod.Symlink;
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
{
TorchVersion.Cpu,
TorchVersion.Rocm,
TorchVersion.DirectMl,
TorchVersion.Cuda
};
public VladAutomatic(IGithubApiCache githubApi, ISettingsManager settingsManager, IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper) :
base(githubApi, settingsManager, downloadService, prerequisiteHelper)
{
}
// https://github.com/vladmandic/automatic/blob/master/modules/shared.py#L324 public override IEnumerable<TorchVersion> AvailableTorchVersions =>
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders => new() new[] { TorchVersion.Cpu, TorchVersion.Rocm, TorchVersion.DirectMl, TorchVersion.Cuda };
{
[SharedFolderType.StableDiffusion] = new[] {"models/Stable-diffusion"},
[SharedFolderType.Diffusers] = new[] {"models/Diffusers"},
[SharedFolderType.VAE] = new[] {"models/VAE"},
[SharedFolderType.TextualInversion] = new[] {"models/embeddings"},
[SharedFolderType.Hypernetwork] = new[] {"models/hypernetworks"},
[SharedFolderType.Codeformer] = new[] {"models/Codeformer"},
[SharedFolderType.GFPGAN] = new[] {"models/GFPGAN"},
[SharedFolderType.BSRGAN] = new[] {"models/BSRGAN"},
[SharedFolderType.ESRGAN] = new[] {"models/ESRGAN"},
[SharedFolderType.RealESRGAN] = new[] {"models/RealESRGAN"},
[SharedFolderType.ScuNET] = new[] {"models/ScuNET"},
[SharedFolderType.SwinIR] = new[] {"models/SwinIR"},
[SharedFolderType.LDSR] = new[] {"models/LDSR"},
[SharedFolderType.CLIP] = new[] {"models/CLIP"},
[SharedFolderType.Lora] = new[] {"models/Lora"},
[SharedFolderType.LyCORIS] = new[] {"models/LyCORIS"},
[SharedFolderType.ControlNet] = new[] {"models/ControlNet"}
};
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] public VladAutomatic(
public override List<LaunchOptionDefinition> LaunchOptions => new() IGithubApiCache githubApi,
{ ISettingsManager settingsManager,
new() IDownloadService downloadService,
{ IPrerequisiteHelper prerequisiteHelper
Name = "Host", )
Type = LaunchOptionType.String, : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
DefaultValue = "localhost",
Options = new() {"--server-name"} // https://github.com/vladmandic/automatic/blob/master/modules/shared.py#L324
}, public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
new() new()
{ {
Name = "Port", [SharedFolderType.StableDiffusion] = new[] { "models/Stable-diffusion" },
Type = LaunchOptionType.String, [SharedFolderType.Diffusers] = new[] { "models/Diffusers" },
DefaultValue = "7860", [SharedFolderType.VAE] = new[] { "models/VAE" },
Options = new() {"--port"} [SharedFolderType.TextualInversion] = new[] { "models/embeddings" },
}, [SharedFolderType.Hypernetwork] = new[] { "models/hypernetworks" },
[SharedFolderType.Codeformer] = new[] { "models/Codeformer" },
[SharedFolderType.GFPGAN] = new[] { "models/GFPGAN" },
[SharedFolderType.BSRGAN] = new[] { "models/BSRGAN" },
[SharedFolderType.ESRGAN] = new[] { "models/ESRGAN" },
[SharedFolderType.RealESRGAN] = new[] { "models/RealESRGAN" },
[SharedFolderType.ScuNET] = new[] { "models/ScuNET" },
[SharedFolderType.SwinIR] = new[] { "models/SwinIR" },
[SharedFolderType.LDSR] = new[] { "models/LDSR" },
[SharedFolderType.CLIP] = new[] { "models/CLIP" },
[SharedFolderType.Lora] = new[] { "models/Lora" },
[SharedFolderType.LyCORIS] = new[] { "models/LyCORIS" },
[SharedFolderType.ControlNet] = new[] { "models/ControlNet" }
};
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List<LaunchOptionDefinition> LaunchOptions =>
new() new()
{ {
Name = "VRAM", new()
Type = LaunchOptionType.Bool,
InitialValue = HardwareHelper.IterGpuInfo().Select(gpu => gpu.MemoryLevel).Max() switch
{ {
Level.Low => "--lowvram", Name = "Host",
Level.Medium => "--medvram", Type = LaunchOptionType.String,
_ => null DefaultValue = "localhost",
Options = new() { "--server-name" }
}, },
Options = new() { "--lowvram", "--medvram" } new()
}, {
new() Name = "Port",
{ Type = LaunchOptionType.String,
Name = "Force use of Intel OneAPI XPU backend", DefaultValue = "7860",
Type = LaunchOptionType.Bool, Options = new() { "--port" }
Options = new() { "--use-ipex" } },
}, new()
new() {
{ Name = "VRAM",
Name = "Use DirectML if no compatible GPU is detected", Type = LaunchOptionType.Bool,
Type = LaunchOptionType.Bool, InitialValue = HardwareHelper
InitialValue = HardwareHelper.PreferDirectML(), .IterGpuInfo()
Options = new() { "--use-directml" } .Select(gpu => gpu.MemoryLevel)
}, .Max() switch
new() {
{ Level.Low => "--lowvram",
Name = "Force use of Nvidia CUDA backend", Level.Medium => "--medvram",
Type = LaunchOptionType.Bool, _ => null
InitialValue = HardwareHelper.HasNvidiaGpu(), },
Options = new() { "--use-cuda" } Options = new() { "--lowvram", "--medvram" }
}, },
new() new()
{ {
Name = "Force use of AMD ROCm backend", Name = "Force use of Intel OneAPI XPU backend",
Type = LaunchOptionType.Bool, Type = LaunchOptionType.Bool,
InitialValue = HardwareHelper.PreferRocm(), Options = new() { "--use-ipex" }
Options = new() { "--use-rocm" } },
}, new()
new() {
{ Name = "Use DirectML if no compatible GPU is detected",
Name = "CUDA Device ID", Type = LaunchOptionType.Bool,
Type = LaunchOptionType.String, InitialValue = HardwareHelper.PreferDirectML(),
Options = new() { "--device-id" } Options = new() { "--use-directml" }
}, },
new() new()
{ {
Name = "API", Name = "Force use of Nvidia CUDA backend",
Type = LaunchOptionType.Bool, Type = LaunchOptionType.Bool,
Options = new() { "--api" } InitialValue = HardwareHelper.HasNvidiaGpu(),
}, Options = new() { "--use-cuda" }
new() },
{ new()
Name = "Debug Logging", {
Type = LaunchOptionType.Bool, Name = "Force use of AMD ROCm backend",
Options = new() { "--debug" } Type = LaunchOptionType.Bool,
}, InitialValue = HardwareHelper.PreferRocm(),
LaunchOptionDefinition.Extras Options = new() { "--use-rocm" }
}; },
new()
{
Name = "CUDA Device ID",
Type = LaunchOptionType.String,
Options = new() { "--device-id" }
},
new()
{
Name = "API",
Type = LaunchOptionType.Bool,
Options = new() { "--api" }
},
new()
{
Name = "Debug Logging",
Type = LaunchOptionType.Bool,
Options = new() { "--debug" }
},
LaunchOptionDefinition.Extras
};
public override string ExtraLaunchArguments => ""; public override string ExtraLaunchArguments => "";
public override Task<string> GetLatestVersion() => Task.FromResult("master"); public override Task<string> GetLatestVersion() => Task.FromResult("master");
public override async Task InstallPackage(string installLocation, TorchVersion torchVersion, public override async Task InstallPackage(
IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
progress?.Report(new ProgressReport(-1f, "Installing package...", isIndeterminate: true)); progress?.Report(new ProgressReport(-1f, "Installing package...", isIndeterminate: true));
// Setup venv // Setup venv
@ -166,11 +170,13 @@ public class VladAutomatic : BaseGitPackage
{ {
// Run initial install // Run initial install
case TorchVersion.Cuda: case TorchVersion.Cuda:
await venvRunner.CustomInstall("launch.py --use-cuda --debug --test", OnConsoleOutput) await venvRunner
.CustomInstall("launch.py --use-cuda --debug --test", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
break; break;
case TorchVersion.Rocm: case TorchVersion.Rocm:
await venvRunner.CustomInstall("launch.py --use-rocm --debug --test", OnConsoleOutput) await venvRunner
.CustomInstall("launch.py --use-rocm --debug --test", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
break; break;
case TorchVersion.DirectMl: case TorchVersion.DirectMl:
@ -180,7 +186,8 @@ public class VladAutomatic : BaseGitPackage
break; break;
default: default:
// CPU // CPU
await venvRunner.CustomInstall("launch.py --debug --test", OnConsoleOutput) await venvRunner
.CustomInstall("launch.py --debug --test", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
break; break;
} }
@ -188,11 +195,20 @@ public class VladAutomatic : BaseGitPackage
progress?.Report(new ProgressReport(1f, isIndeterminate: false)); progress?.Report(new ProgressReport(1f, isIndeterminate: false));
} }
public override async Task DownloadPackage(string installLocation, public override async Task DownloadPackage(
DownloadPackageVersionOptions downloadOptions, IProgress<ProgressReport>? progress = null) string installLocation,
DownloadPackageVersionOptions downloadOptions,
IProgress<ProgressReport>? progress = null
)
{ {
progress?.Report(new ProgressReport(-1f, message: "Downloading package...", progress?.Report(
isIndeterminate: true, type: ProgressType.Download)); new ProgressReport(
-1f,
message: "Downloading package...",
isIndeterminate: true,
type: ProgressType.Download
)
);
var installDir = new DirectoryPath(installLocation); var installDir = new DirectoryPath(installLocation);
installDir.Create(); installDir.Create();
@ -200,22 +216,38 @@ public class VladAutomatic : BaseGitPackage
if (!string.IsNullOrWhiteSpace(downloadOptions.CommitHash)) if (!string.IsNullOrWhiteSpace(downloadOptions.CommitHash))
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit(installDir.Parent ?? "", "clone", "https://github.com/vladmandic/automatic", .RunGit(
installDir.Name).ConfigureAwait(false); installDir.Parent ?? "",
"clone",
"https://github.com/vladmandic/automatic",
installDir.Name
)
.ConfigureAwait(false);
await PrerequisiteHelper.RunGit(installLocation, "checkout", downloadOptions.CommitHash) await PrerequisiteHelper
.RunGit(installLocation, "checkout", downloadOptions.CommitHash)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
else if (!string.IsNullOrWhiteSpace(downloadOptions.BranchName)) else if (!string.IsNullOrWhiteSpace(downloadOptions.BranchName))
{ {
await PrerequisiteHelper await PrerequisiteHelper
.RunGit(installDir.Parent ?? "", "clone", "-b", downloadOptions.BranchName, .RunGit(
"https://github.com/vladmandic/automatic", installDir.Name) installDir.Parent ?? "",
"clone",
"-b",
downloadOptions.BranchName,
"https://github.com/vladmandic/automatic",
installDir.Name
)
.ConfigureAwait(false); .ConfigureAwait(false);
} }
} }
public override async Task RunPackage(string installedPackagePath, string command, string arguments) public override async Task RunPackage(
string installedPackagePath,
string command,
string arguments
)
{ {
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
@ -245,35 +277,44 @@ public class VladAutomatic : BaseGitPackage
VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, HandleExit); VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, HandleExit);
} }
public override async Task<InstalledPackageVersion> Update(InstalledPackage installedPackage, public override async Task<InstalledPackageVersion> Update(
TorchVersion torchVersion, IProgress<ProgressReport>? progress = null, InstalledPackage installedPackage,
bool includePrerelease = false) TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null,
bool includePrerelease = false
)
{ {
if (installedPackage.Version is null) if (installedPackage.Version is null)
{ {
throw new Exception("Version is null"); throw new Exception("Version is null");
} }
progress?.Report(new ProgressReport(-1f, message: "Downloading package update...", progress?.Report(
isIndeterminate: true, type: ProgressType.Update)); new ProgressReport(
-1f,
await PrerequisiteHelper.RunGit(installedPackage.FullPath, "checkout", message: "Downloading package update...",
installedPackage.Version.InstalledBranch).ConfigureAwait(false); isIndeterminate: true,
type: ProgressType.Update
)
);
await PrerequisiteHelper
.RunGit(installedPackage.FullPath, "checkout", installedPackage.Version.InstalledBranch)
.ConfigureAwait(false);
var venvRunner = new PyVenvRunner(Path.Combine(installedPackage.FullPath!, "venv")); var venvRunner = new PyVenvRunner(Path.Combine(installedPackage.FullPath!, "venv"));
venvRunner.WorkingDirectory = installedPackage.FullPath!; venvRunner.WorkingDirectory = installedPackage.FullPath!;
venvRunner.EnvironmentVariables = SettingsManager.Settings.EnvironmentVariables; venvRunner.EnvironmentVariables = SettingsManager.Settings.EnvironmentVariables;
await venvRunner.CustomInstall("launch.py --upgrade --test", OnConsoleOutput) await venvRunner
.CustomInstall("launch.py --upgrade --test", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
try try
{ {
var output = var output = await PrerequisiteHelper
await PrerequisiteHelper .GetGitOutput(installedPackage.FullPath, "rev-parse", "HEAD")
.GetGitOutput(installedPackage.FullPath, "rev-parse", "HEAD") .ConfigureAwait(false);
.ConfigureAwait(false);
return new InstalledPackageVersion return new InstalledPackageVersion
{ {
@ -287,9 +328,14 @@ public class VladAutomatic : BaseGitPackage
} }
finally finally
{ {
progress?.Report(new ProgressReport(1f, message: "Update Complete", progress?.Report(
isIndeterminate: false, new ProgressReport(
type: ProgressType.Update)); 1f,
message: "Update Complete",
isIndeterminate: false,
type: ProgressType.Update
)
);
} }
return new InstalledPackageVersion return new InstalledPackageVersion
@ -298,7 +344,10 @@ public class VladAutomatic : BaseGitPackage
}; };
} }
public override Task SetupModelFolders(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod) public override Task SetupModelFolders(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
{ {
switch (sharedFolderMethod) switch (sharedFolderMethod)
{ {
@ -335,38 +384,56 @@ public class VladAutomatic : BaseGitPackage
configRoot["vae_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "VAE"); configRoot["vae_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "VAE");
configRoot["lora_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "Lora"); configRoot["lora_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "Lora");
configRoot["lyco_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "LyCORIS"); configRoot["lyco_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "LyCORIS");
configRoot["embeddings_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "TextualInversion"); configRoot["embeddings_dir"] = Path.Combine(
configRoot["hypernetwork_dir"] = Path.Combine(SettingsManager.ModelsDirectory, "Hypernetwork"); SettingsManager.ModelsDirectory,
configRoot["codeformer_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "Codeformer"); "TextualInversion"
);
configRoot["hypernetwork_dir"] = Path.Combine(
SettingsManager.ModelsDirectory,
"Hypernetwork"
);
configRoot["codeformer_models_path"] = Path.Combine(
SettingsManager.ModelsDirectory,
"Codeformer"
);
configRoot["gfpgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "GFPGAN"); configRoot["gfpgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "GFPGAN");
configRoot["bsrgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "BSRGAN"); configRoot["bsrgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "BSRGAN");
configRoot["esrgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "ESRGAN"); configRoot["esrgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "ESRGAN");
configRoot["realesrgan_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "RealESRGAN"); configRoot["realesrgan_models_path"] = Path.Combine(
SettingsManager.ModelsDirectory,
"RealESRGAN"
);
configRoot["scunet_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "ScuNET"); configRoot["scunet_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "ScuNET");
configRoot["swinir_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "SwinIR"); configRoot["swinir_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "SwinIR");
configRoot["ldsr_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "LDSR"); configRoot["ldsr_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "LDSR");
configRoot["clip_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "CLIP"); configRoot["clip_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "CLIP");
configRoot["control_net_models_path"] = Path.Combine(SettingsManager.ModelsDirectory, "ControlNet"); configRoot["control_net_models_path"] = Path.Combine(
SettingsManager.ModelsDirectory,
var configJsonStr = JsonSerializer.Serialize(configRoot, new JsonSerializerOptions "ControlNet"
{ );
WriteIndented = true
}); var configJsonStr = JsonSerializer.Serialize(
configRoot,
new JsonSerializerOptions { WriteIndented = true }
);
File.WriteAllText(configJsonPath, configJsonStr); File.WriteAllText(configJsonPath, configJsonStr);
return Task.CompletedTask; return Task.CompletedTask;
} }
public override Task UpdateModelFolders(DirectoryPath installDirectory, public override Task UpdateModelFolders(
SharedFolderMethod sharedFolderMethod) => DirectoryPath installDirectory,
SetupModelFolders(installDirectory, sharedFolderMethod); SharedFolderMethod sharedFolderMethod
) => SetupModelFolders(installDirectory, sharedFolderMethod);
public override Task RemoveModelFolderLinks(DirectoryPath installDirectory, public override Task RemoveModelFolderLinks(
SharedFolderMethod sharedFolderMethod) => DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
) =>
sharedFolderMethod switch sharedFolderMethod switch
{ {
SharedFolderMethod.Symlink => base.RemoveModelFolderLinks(installDirectory, SharedFolderMethod.Symlink
sharedFolderMethod), => base.RemoveModelFolderLinks(installDirectory, sharedFolderMethod),
SharedFolderMethod.None => Task.CompletedTask, SharedFolderMethod.None => Task.CompletedTask,
_ => Task.CompletedTask _ => Task.CompletedTask
}; };

226
StabilityMatrix.Core/Models/Packages/VoltaML.cs

@ -14,128 +14,130 @@ public class VoltaML : BaseGitPackage
public override string DisplayName { get; set; } = "VoltaML"; public override string DisplayName { get; set; } = "VoltaML";
public override string Author => "VoltaML"; public override string Author => "VoltaML";
public override string LicenseType => "GPL-3.0"; public override string LicenseType => "GPL-3.0";
public override string LicenseUrl => public override string LicenseUrl =>
"https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/License"; "https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/License";
public override string Blurb => "Fast Stable Diffusion with support for AITemplate"; public override string Blurb => "Fast Stable Diffusion with support for AITemplate";
public override string LaunchCommand => "main.py"; public override string LaunchCommand => "main.py";
public override Uri PreviewImageUri => new( public override Uri PreviewImageUri =>
"https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8"); new(
"https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8"
);
// There are releases but the manager just downloads the latest commit anyways, // There are releases but the manager just downloads the latest commit anyways,
// so we'll just limit to commit mode to be more consistent // so we'll just limit to commit mode to be more consistent
public override bool ShouldIgnoreReleases => true; public override bool ShouldIgnoreReleases => true;
public VoltaML( public VoltaML(
IGithubApiCache githubApi, IGithubApiCache githubApi,
ISettingsManager settingsManager, ISettingsManager settingsManager,
IDownloadService downloadService, IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper) : IPrerequisiteHelper prerequisiteHelper
base(githubApi, settingsManager, downloadService, prerequisiteHelper) )
{ : base(githubApi, settingsManager, downloadService, prerequisiteHelper) { }
}
// https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L86 // https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L86
public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders => new() public override Dictionary<SharedFolderType, IReadOnlyList<string>> SharedFolders =>
{ new()
[SharedFolderType.StableDiffusion] = new[] {"data/models"}, {
[SharedFolderType.Lora] = new[] {"data/lora"}, [SharedFolderType.StableDiffusion] = new[] { "data/models" },
[SharedFolderType.TextualInversion] = new[] {"data/textual-inversion"}, [SharedFolderType.Lora] = new[] { "data/lora" },
}; [SharedFolderType.TextualInversion] = new[] { "data/textual-inversion" },
};
public override SharedFolderMethod RecommendedSharedFolderMethod =>
SharedFolderMethod.Symlink;
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] {TorchVersion.None}; public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.Symlink;
public override IEnumerable<TorchVersion> AvailableTorchVersions => new[] { TorchVersion.None };
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods =>
new[] { SharedFolderMethod.Symlink, SharedFolderMethod.None };
public override IEnumerable<SharedFolderMethod> AvailableSharedFolderMethods => new[]
{
SharedFolderMethod.Symlink,
SharedFolderMethod.None
};
// https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L45 // https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/main.py#L45
public override List<LaunchOptionDefinition> LaunchOptions => new List<LaunchOptionDefinition> public override List<LaunchOptionDefinition> LaunchOptions =>
{ new List<LaunchOptionDefinition>
new()
{ {
Name = "Log Level", new()
Type = LaunchOptionType.Bool,
DefaultValue = "--log-level INFO",
Options =
{ {
"--log-level DEBUG", Name = "Log Level",
"--log-level INFO", Type = LaunchOptionType.Bool,
"--log-level WARNING", DefaultValue = "--log-level INFO",
"--log-level ERROR", Options =
"--log-level CRITICAL" {
} "--log-level DEBUG",
}, "--log-level INFO",
new() "--log-level WARNING",
{ "--log-level ERROR",
Name = "Use ngrok to expose the API", "--log-level CRITICAL"
Type = LaunchOptionType.Bool, }
Options = {"--ngrok"} },
}, new()
new()
{
Name = "Expose the API to the network",
Type = LaunchOptionType.Bool,
Options = {"--host"}
},
new()
{
Name = "Skip virtualenv check",
Type = LaunchOptionType.Bool,
InitialValue = true,
Options = {"--in-container"}
},
new()
{
Name = "Force VoltaML to use a specific type of PyTorch distribution",
Type = LaunchOptionType.Bool,
Options =
{ {
"--pytorch-type cpu", Name = "Use ngrok to expose the API",
"--pytorch-type cuda", Type = LaunchOptionType.Bool,
"--pytorch-type rocm", Options = { "--ngrok" }
"--pytorch-type directml", },
"--pytorch-type intel", new()
"--pytorch-type vulkan" {
} Name = "Expose the API to the network",
}, Type = LaunchOptionType.Bool,
new() Options = { "--host" }
{ },
Name = "Run in tandem with the Discord bot", new()
Type = LaunchOptionType.Bool, {
Options = {"--bot"} Name = "Skip virtualenv check",
}, Type = LaunchOptionType.Bool,
new() InitialValue = true,
{ Options = { "--in-container" }
Name = "Enable Cloudflare R2 bucket upload support", },
Type = LaunchOptionType.Bool, new()
Options = {"--enable-r2"} {
}, Name = "Force VoltaML to use a specific type of PyTorch distribution",
new() Type = LaunchOptionType.Bool,
{ Options =
Name = "Port", {
Type = LaunchOptionType.String, "--pytorch-type cpu",
DefaultValue = "5003", "--pytorch-type cuda",
Options = {"--port"} "--pytorch-type rocm",
}, "--pytorch-type directml",
new() "--pytorch-type intel",
{ "--pytorch-type vulkan"
Name = "Only install requirements and exit", }
Type = LaunchOptionType.Bool, },
Options = {"--install-only"} new()
}, {
LaunchOptionDefinition.Extras Name = "Run in tandem with the Discord bot",
}; Type = LaunchOptionType.Bool,
Options = { "--bot" }
},
new()
{
Name = "Enable Cloudflare R2 bucket upload support",
Type = LaunchOptionType.Bool,
Options = { "--enable-r2" }
},
new()
{
Name = "Port",
Type = LaunchOptionType.String,
DefaultValue = "5003",
Options = { "--port" }
},
new()
{
Name = "Only install requirements and exit",
Type = LaunchOptionType.Bool,
Options = { "--install-only" }
},
LaunchOptionDefinition.Extras
};
public override Task<string> GetLatestVersion() => Task.FromResult("main"); public override Task<string> GetLatestVersion() => Task.FromResult("main");
public override async Task InstallPackage(string installLocation, TorchVersion torchVersion, public override async Task InstallPackage(
IProgress<ProgressReport>? progress = null) string installLocation,
TorchVersion torchVersion,
IProgress<ProgressReport>? progress = null
)
{ {
await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false); await base.InstallPackage(installLocation, torchVersion, progress).ConfigureAwait(false);
@ -147,24 +149,30 @@ public class VoltaML : BaseGitPackage
await venvRunner.Setup(true).ConfigureAwait(false); await venvRunner.Setup(true).ConfigureAwait(false);
// Install requirements // Install requirements
progress?.Report(new ProgressReport(-1, "Installing Package Requirements", progress?.Report(
isIndeterminate: true)); new ProgressReport(-1, "Installing Package Requirements", isIndeterminate: true)
);
await venvRunner await venvRunner
.PipInstall("rich packaging python-dotenv", OnConsoleOutput) .PipInstall("rich packaging python-dotenv", OnConsoleOutput)
.ConfigureAwait(false); .ConfigureAwait(false);
progress?.Report(new ProgressReport(1, "Installing Package Requirements", progress?.Report(
isIndeterminate: false)); new ProgressReport(1, "Installing Package Requirements", isIndeterminate: false)
);
} }
public override async Task RunPackage(string installedPackagePath, string command, string arguments) public override async Task RunPackage(
string installedPackagePath,
string command,
string arguments
)
{ {
await SetupVenv(installedPackagePath).ConfigureAwait(false); await SetupVenv(installedPackagePath).ConfigureAwait(false);
var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}"; var args = $"\"{Path.Combine(installedPackagePath, command)}\" {arguments}";
var foundIndicator = false; var foundIndicator = false;
void HandleConsoleOutput(ProcessOutput s) void HandleConsoleOutput(ProcessOutput s)
{ {
OnConsoleOutput(s); OnConsoleOutput(s);
@ -178,17 +186,17 @@ public class VoltaML : BaseGitPackage
if (!foundIndicator) if (!foundIndicator)
return; return;
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
var match = regex.Match(s.Text); var match = regex.Match(s.Text);
if (!match.Success) if (!match.Success)
return; return;
WebUrl = match.Value; WebUrl = match.Value;
OnStartupComplete(WebUrl); OnStartupComplete(WebUrl);
foundIndicator = false; foundIndicator = false;
} }
VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit); VenvRunner.RunDetached(args.TrimEnd(), HandleConsoleOutput, OnExit);
} }
} }

2
StabilityMatrix.Core/Services/IModelIndexService.cs

@ -14,7 +14,7 @@ public interface IModelIndexService
/// Get all models of the specified type from the existing index. /// Get all models of the specified type from the existing index.
/// </summary> /// </summary>
Task<IReadOnlyList<LocalModelFile>> GetModelsOfType(SharedFolderType type); Task<IReadOnlyList<LocalModelFile>> GetModelsOfType(SharedFolderType type);
/// <summary> /// <summary>
/// Starts a background task to refresh the local model file index. /// Starts a background task to refresh the local model file index.
/// </summary> /// </summary>

11
StabilityMatrix.Core/Services/ISettingsManager.cs

@ -23,7 +23,7 @@ public interface ISettingsManager
/// Will fire instantly if it is already set. /// Will fire instantly if it is already set.
/// </summary> /// </summary>
void RegisterOnLibraryDirSet(Action<string> handler); void RegisterOnLibraryDirSet(Action<string> handler);
/// <inheritdoc /> /// <inheritdoc />
SettingsTransaction BeginTransaction(); SettingsTransaction BeginTransaction();
@ -35,14 +35,17 @@ public interface ISettingsManager
/// <inheritdoc /> /// <inheritdoc />
void RelayPropertyFor<T, TValue>( void RelayPropertyFor<T, TValue>(
T source, T source,
Expression<Func<T, TValue>> sourceProperty, Expression<Func<T, TValue>> sourceProperty,
Expression<Func<Settings, TValue>> settingsProperty) where T : INotifyPropertyChanged; Expression<Func<Settings, TValue>> settingsProperty
)
where T : INotifyPropertyChanged;
/// <inheritdoc /> /// <inheritdoc />
void RegisterPropertyChangedHandler<T>( void RegisterPropertyChangedHandler<T>(
Expression<Func<Settings, T>> settingsProperty, Expression<Func<Settings, T>> settingsProperty,
Action<T> onPropertyChanged); Action<T> onPropertyChanged
);
/// <summary> /// <summary>
/// Attempts to locate and set the library path /// Attempts to locate and set the library path

71
StabilityMatrix.Core/Services/ModelIndexService.cs

@ -38,9 +38,10 @@ public class ModelIndexService : IModelIndexService
return await liteDbContext.LocalModelFiles return await liteDbContext.LocalModelFiles
.Query() .Query()
.Where(m => m.SharedFolderType == type) .Where(m => m.SharedFolderType == type)
.ToArrayAsync().ConfigureAwait(false); .ToArrayAsync()
.ConfigureAwait(false);
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task RefreshIndex() public async Task RefreshIndex()
{ {
@ -49,19 +50,18 @@ public class ModelIndexService : IModelIndexService
// Start // Start
var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
logger.LogInformation("Refreshing model index..."); logger.LogInformation("Refreshing model index...");
using var db using var db = await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
= await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
var localModelFiles = db.GetCollection<LocalModelFile>("LocalModelFiles")!; var localModelFiles = db.GetCollection<LocalModelFile>("LocalModelFiles")!;
await localModelFiles.DeleteAllAsync().ConfigureAwait(false); await localModelFiles.DeleteAllAsync().ConfigureAwait(false);
// Record start of actual indexing // Record start of actual indexing
var indexStart = stopwatch.Elapsed; var indexStart = stopwatch.Elapsed;
var added = 0; var added = 0;
foreach ( foreach (
var file in modelsDir.Info var file in modelsDir.Info
.EnumerateFiles("*.*", SearchOption.AllDirectories) .EnumerateFiles("*.*", SearchOption.AllDirectories)
@ -73,61 +73,74 @@ public class ModelIndexService : IModelIndexService
{ {
continue; continue;
} }
var relativePath = Path.GetRelativePath(modelsDir, file); var relativePath = Path.GetRelativePath(modelsDir, file);
// Get shared folder name // Get shared folder name
var sharedFolderName = relativePath.Split(Path.DirectorySeparatorChar, var sharedFolderName = relativePath.Split(
StringSplitOptions.RemoveEmptyEntries)[0]; Path.DirectorySeparatorChar,
StringSplitOptions.RemoveEmptyEntries
)[0];
// Convert to enum // Convert to enum
var sharedFolderType = Enum.Parse<SharedFolderType>(sharedFolderName, true); var sharedFolderType = Enum.Parse<SharedFolderType>(sharedFolderName, true);
var localModel = new LocalModelFile var localModel = new LocalModelFile
{ {
RelativePath = relativePath, RelativePath = relativePath,
SharedFolderType = sharedFolderType, SharedFolderType = sharedFolderType,
}; };
// Try to find a connected model info // Try to find a connected model info
var jsonPath = file.Directory!.JoinFile( var jsonPath = file.Directory!.JoinFile(
new FilePath(file.NameWithoutExtension, ".cm-info.json")); new FilePath(file.NameWithoutExtension, ".cm-info.json")
);
if (jsonPath.Exists) if (jsonPath.Exists)
{ {
var connectedModelInfo = ConnectedModelInfo.FromJson( var connectedModelInfo = ConnectedModelInfo.FromJson(
await jsonPath.ReadAllTextAsync().ConfigureAwait(false)); await jsonPath.ReadAllTextAsync().ConfigureAwait(false)
);
localModel.ConnectedModelInfo = connectedModelInfo; localModel.ConnectedModelInfo = connectedModelInfo;
} }
// Try to find a preview image // Try to find a preview image
var previewImagePath = LocalModelFile.SupportedImageExtensions var previewImagePath = LocalModelFile.SupportedImageExtensions
.Select(ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}") .Select(
).FirstOrDefault(path => path.Exists); ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}")
)
.FirstOrDefault(path => path.Exists);
if (previewImagePath != null) if (previewImagePath != null)
{ {
localModel.PreviewImageRelativePath = Path.GetRelativePath(modelsDir, previewImagePath); localModel.PreviewImageRelativePath = Path.GetRelativePath(
modelsDir,
previewImagePath
);
} }
// Insert into database // Insert into database
await localModelFiles.InsertAsync(localModel).ConfigureAwait(false); await localModelFiles.InsertAsync(localModel).ConfigureAwait(false);
added++; added++;
} }
// Record end of actual indexing // Record end of actual indexing
var indexEnd = stopwatch.Elapsed; var indexEnd = stopwatch.Elapsed;
await db.CommitAsync().ConfigureAwait(false); await db.CommitAsync().ConfigureAwait(false);
// End // End
stopwatch.Stop(); stopwatch.Stop();
var indexDuration = indexEnd - indexStart; var indexDuration = indexEnd - indexStart;
var dbDuration = stopwatch.Elapsed - indexDuration; var dbDuration = stopwatch.Elapsed - indexDuration;
logger.LogInformation("Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)", logger.LogInformation(
added, indexDuration.TotalMilliseconds, dbDuration.TotalMilliseconds); "Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)",
added,
indexDuration.TotalMilliseconds,
dbDuration.TotalMilliseconds
);
} }
/// <inheritdoc /> /// <inheritdoc />

259
StabilityMatrix.Core/Services/SettingsManager.cs

@ -20,10 +20,16 @@ public class SettingsManager : ISettingsManager
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private static readonly ReaderWriterLockSlim FileLock = new(); private static readonly ReaderWriterLockSlim FileLock = new();
private static readonly string GlobalSettingsPath = Path.Combine(Compat.AppDataHome, "global.json"); private static readonly string GlobalSettingsPath = Path.Combine(
Compat.AppDataHome,
private readonly string? originalEnvPath = Environment.GetEnvironmentVariable("PATH", EnvironmentVariableTarget.Process); "global.json"
);
private readonly string? originalEnvPath = Environment.GetEnvironmentVariable(
"PATH",
EnvironmentVariableTarget.Process
);
// Library properties // Library properties
public bool IsPortableMode { get; private set; } public bool IsPortableMode { get; private set; }
private string? libraryDir; private string? libraryDir;
@ -50,10 +56,10 @@ public class SettingsManager : ISettingsManager
private string SettingsPath => Path.Combine(LibraryDir, "settings.json"); private string SettingsPath => Path.Combine(LibraryDir, "settings.json");
public string ModelsDirectory => Path.Combine(LibraryDir, "Models"); public string ModelsDirectory => Path.Combine(LibraryDir, "Models");
public string DownloadsDirectory => Path.Combine(LibraryDir, ".downloads"); public string DownloadsDirectory => Path.Combine(LibraryDir, ".downloads");
public Settings Settings { get; private set; } = new(); public Settings Settings { get; private set; } = new();
public event EventHandler<string>? LibraryDirChanged; public event EventHandler<string>? LibraryDirChanged;
public event EventHandler<PropertyChangedEventArgs>? SettingsPropertyChanged; public event EventHandler<PropertyChangedEventArgs>? SettingsPropertyChanged;
/// <inheritdoc /> /// <inheritdoc />
@ -66,7 +72,7 @@ public class SettingsManager : ISettingsManager
} }
LibraryDirChanged += Handler; LibraryDirChanged += Handler;
return; return;
void Handler(object? sender, string dir) void Handler(object? sender, string dir)
@ -75,17 +81,19 @@ public class SettingsManager : ISettingsManager
handler(dir); handler(dir);
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public SettingsTransaction BeginTransaction() public SettingsTransaction BeginTransaction()
{ {
if (!IsLibraryDirSet) if (!IsLibraryDirSet)
{ {
throw new InvalidOperationException("LibraryDir not set when BeginTransaction was called"); throw new InvalidOperationException(
"LibraryDir not set when BeginTransaction was called"
);
} }
return new SettingsTransaction(this, SaveSettingsAsync); return new SettingsTransaction(this, SaveSettingsAsync);
} }
/// <inheritdoc /> /// <inheritdoc />
public void Transaction(Action<Settings> func, bool ignoreMissingLibraryDir = false) public void Transaction(Action<Settings> func, bool ignoreMissingLibraryDir = false)
{ {
@ -102,97 +110,111 @@ public class SettingsManager : ISettingsManager
func(transaction.Settings); func(transaction.Settings);
transaction.Dispose(); transaction.Dispose();
} }
/// <inheritdoc /> /// <inheritdoc />
public void Transaction<TValue>(Expression<Func<Settings, TValue>> expression, TValue value) public void Transaction<TValue>(Expression<Func<Settings, TValue>> expression, TValue value)
{ {
if (expression.Body is not MemberExpression memberExpression) if (expression.Body is not MemberExpression memberExpression)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Expression must be a member expression, not {expression.Body.NodeType}"); $"Expression must be a member expression, not {expression.Body.NodeType}"
);
} }
var propertyInfo = memberExpression.Member as PropertyInfo; var propertyInfo = memberExpression.Member as PropertyInfo;
if (propertyInfo == null) if (propertyInfo == null)
{ {
throw new ArgumentException( throw new ArgumentException(
$"Expression member must be a property, not {memberExpression.Member.MemberType}"); $"Expression member must be a property, not {memberExpression.Member.MemberType}"
);
} }
var name = propertyInfo.Name; var name = propertyInfo.Name;
// Set value // Set value
using var transaction = BeginTransaction(); using var transaction = BeginTransaction();
propertyInfo.SetValue(transaction.Settings, value); propertyInfo.SetValue(transaction.Settings, value);
// Invoke property changed event // Invoke property changed event
SettingsPropertyChanged?.Invoke(this, new PropertyChangedEventArgs(name)); SettingsPropertyChanged?.Invoke(this, new PropertyChangedEventArgs(name));
} }
/// <inheritdoc /> /// <inheritdoc />
public void RelayPropertyFor<T, TValue>( public void RelayPropertyFor<T, TValue>(
T source, T source,
Expression<Func<T, TValue>> sourceProperty, Expression<Func<T, TValue>> sourceProperty,
Expression<Func<Settings, TValue>> settingsProperty) where T : INotifyPropertyChanged Expression<Func<Settings, TValue>> settingsProperty
)
where T : INotifyPropertyChanged
{ {
var sourceGetter = sourceProperty.Compile(); var sourceGetter = sourceProperty.Compile();
var (propertyName, assigner) = Expressions.GetAssigner(sourceProperty); var (propertyName, assigner) = Expressions.GetAssigner(sourceProperty);
var sourceSetter = assigner.Compile(); var sourceSetter = assigner.Compile();
var settingsGetter = settingsProperty.Compile(); var settingsGetter = settingsProperty.Compile();
var (targetPropertyName, settingsAssigner) = Expressions.GetAssigner(settingsProperty); var (targetPropertyName, settingsAssigner) = Expressions.GetAssigner(settingsProperty);
var settingsSetter = settingsAssigner.Compile(); var settingsSetter = settingsAssigner.Compile();
var sourceTypeName = source.GetType().Name; var sourceTypeName = source.GetType().Name;
// Update source when settings change // Update source when settings change
SettingsPropertyChanged += (_, args) => SettingsPropertyChanged += (_, args) =>
{ {
if (args.PropertyName != propertyName) return; if (args.PropertyName != propertyName)
return;
Logger.Trace( Logger.Trace(
"[RelayPropertyFor] " + "[RelayPropertyFor] "
"Settings.{TargetProperty:l} -> {SourceType:l}.{SourceProperty:l}", + "Settings.{TargetProperty:l} -> {SourceType:l}.{SourceProperty:l}",
targetPropertyName, sourceTypeName, propertyName); targetPropertyName,
sourceTypeName,
propertyName
);
sourceSetter(source, settingsGetter(Settings)); sourceSetter(source, settingsGetter(Settings));
}; };
// Set and Save settings when source changes // Set and Save settings when source changes
source.PropertyChanged += (_, args) => source.PropertyChanged += (_, args) =>
{ {
if (args.PropertyName != propertyName) return; if (args.PropertyName != propertyName)
return;
Logger.Trace( Logger.Trace(
"[RelayPropertyFor] " + "[RelayPropertyFor] "
"{SourceType:l}.{SourceProperty:l} -> Settings.{TargetProperty:l}", + "{SourceType:l}.{SourceProperty:l} -> Settings.{TargetProperty:l}",
sourceTypeName, propertyName, targetPropertyName); sourceTypeName,
propertyName,
targetPropertyName
);
settingsSetter(Settings, sourceGetter(source)); settingsSetter(Settings, sourceGetter(source));
SaveSettingsAsync().SafeFireAndForget(); SaveSettingsAsync().SafeFireAndForget();
// Invoke property changed event // Invoke property changed event
SettingsPropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName)); SettingsPropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName));
}; };
} }
/// <inheritdoc /> /// <inheritdoc />
public void RegisterPropertyChangedHandler<T>( public void RegisterPropertyChangedHandler<T>(
Expression<Func<Settings, T>> settingsProperty, Expression<Func<Settings, T>> settingsProperty,
Action<T> onPropertyChanged) Action<T> onPropertyChanged
)
{ {
var settingsGetter = settingsProperty.Compile(); var settingsGetter = settingsProperty.Compile();
var (propertyName, _) = Expressions.GetAssigner(settingsProperty); var (propertyName, _) = Expressions.GetAssigner(settingsProperty);
// Invoke handler when settings change // Invoke handler when settings change
SettingsPropertyChanged += (_, args) => SettingsPropertyChanged += (_, args) =>
{ {
if (args.PropertyName != propertyName) return; if (args.PropertyName != propertyName)
return;
onPropertyChanged(settingsGetter(Settings)); onPropertyChanged(settingsGetter(Settings));
}; };
} }
/// <summary> /// <summary>
/// Attempts to locate and set the library path /// Attempts to locate and set the library path
/// Return true if found, false otherwise /// Return true if found, false otherwise
@ -209,18 +231,21 @@ public class SettingsManager : ISettingsManager
LoadSettings(); LoadSettings();
return true; return true;
} }
// 2. Check %APPDATA%/StabilityMatrix/library.json // 2. Check %APPDATA%/StabilityMatrix/library.json
FilePath libraryJsonFile = Compat.AppDataHome + "library.json"; FilePath libraryJsonFile = Compat.AppDataHome + "library.json";
if (!libraryJsonFile.Exists) return false; if (!libraryJsonFile.Exists)
return false;
try try
{ {
var libraryJson = libraryJsonFile.ReadAllText(); var libraryJson = libraryJsonFile.ReadAllText();
var librarySettings = JsonSerializer.Deserialize<LibrarySettings>(libraryJson); var librarySettings = JsonSerializer.Deserialize<LibrarySettings>(libraryJson);
if (!string.IsNullOrWhiteSpace(librarySettings?.LibraryPath) if (
&& Directory.Exists(librarySettings?.LibraryPath)) !string.IsNullOrWhiteSpace(librarySettings?.LibraryPath)
&& Directory.Exists(librarySettings?.LibraryPath)
)
{ {
LibraryDir = librarySettings.LibraryPath; LibraryDir = librarySettings.LibraryPath;
SetStaticLibraryPaths(); SetStaticLibraryPaths();
@ -252,13 +277,16 @@ public class SettingsManager : ISettingsManager
var libraryJsonFile = Compat.AppDataHome.JoinFile("library.json"); var libraryJsonFile = Compat.AppDataHome.JoinFile("library.json");
var library = new LibrarySettings { LibraryPath = path }; var library = new LibrarySettings { LibraryPath = path };
var libraryJson = JsonSerializer.Serialize(library, new JsonSerializerOptions { WriteIndented = true }); var libraryJson = JsonSerializer.Serialize(
library,
new JsonSerializerOptions { WriteIndented = true }
);
libraryJsonFile.WriteAllText(libraryJson); libraryJsonFile.WriteAllText(libraryJson);
// actually create the LibraryPath directory // actually create the LibraryPath directory
Directory.CreateDirectory(path); Directory.CreateDirectory(path);
} }
/// <summary> /// <summary>
/// Enable and create settings files for portable mode /// Enable and create settings files for portable mode
/// Creates the ./Data directory and the `.sm-portable` marker file /// Creates the ./Data directory and the `.sm-portable` marker file
@ -280,18 +308,21 @@ public class SettingsManager : ISettingsManager
/// </summary> /// </summary>
public IEnumerable<InstalledPackage> GetOldInstalledPackages() public IEnumerable<InstalledPackage> GetOldInstalledPackages()
{ {
var oldSettingsPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), var oldSettingsPath = Path.Combine(
"StabilityMatrix", "settings.json"); Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData),
"StabilityMatrix",
"settings.json"
);
if (!File.Exists(oldSettingsPath)) if (!File.Exists(oldSettingsPath))
yield break; yield break;
var oldSettingsJson = File.ReadAllText(oldSettingsPath); var oldSettingsJson = File.ReadAllText(oldSettingsPath);
var oldSettings = JsonSerializer.Deserialize<Settings>(oldSettingsJson, new JsonSerializerOptions var oldSettings = JsonSerializer.Deserialize<Settings>(
{ oldSettingsJson,
Converters = { new JsonStringEnumConverter() } new JsonSerializerOptions { Converters = { new JsonStringEnumConverter() } }
}); );
// Absolute paths are old formats requiring migration // Absolute paths are old formats requiring migration
#pragma warning disable CS0618 #pragma warning disable CS0618
var oldPackages = oldSettings?.InstalledPackages.Where(package => package.Path != null); var oldPackages = oldSettings?.InstalledPackages.Where(package => package.Path != null);
@ -299,7 +330,7 @@ public class SettingsManager : ISettingsManager
if (oldPackages == null) if (oldPackages == null)
yield break; yield break;
foreach (var package in oldPackages) foreach (var package in oldPackages)
{ {
yield return package; yield return package;
@ -308,24 +339,27 @@ public class SettingsManager : ISettingsManager
public Guid GetOldActivePackageId() public Guid GetOldActivePackageId()
{ {
var oldSettingsPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), var oldSettingsPath = Path.Combine(
"StabilityMatrix", "settings.json"); Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData),
"StabilityMatrix",
"settings.json"
);
if (!File.Exists(oldSettingsPath)) if (!File.Exists(oldSettingsPath))
return default; return default;
var oldSettingsJson = File.ReadAllText(oldSettingsPath); var oldSettingsJson = File.ReadAllText(oldSettingsPath);
var oldSettings = JsonSerializer.Deserialize<Settings>(oldSettingsJson, new JsonSerializerOptions var oldSettings = JsonSerializer.Deserialize<Settings>(
{ oldSettingsJson,
Converters = { new JsonStringEnumConverter() } new JsonSerializerOptions { Converters = { new JsonStringEnumConverter() } }
}); );
if (oldSettings == null) if (oldSettings == null)
return default; return default;
return oldSettings.ActiveInstalledPackageId ?? default; return oldSettings.ActiveInstalledPackageId ?? default;
} }
public void AddPathExtension(string pathExtension) public void AddPathExtension(string pathExtension)
{ {
Settings.PathExtensions ??= new List<string>(); Settings.PathExtensions ??= new List<string>();
@ -337,13 +371,14 @@ public class SettingsManager : ISettingsManager
{ {
return string.Join(";", Settings.PathExtensions ?? new List<string>()); return string.Join(";", Settings.PathExtensions ?? new List<string>());
} }
/// <summary> /// <summary>
/// Insert path extensions to the front of the PATH environment variable /// Insert path extensions to the front of the PATH environment variable
/// </summary> /// </summary>
public void InsertPathExtensions() public void InsertPathExtensions()
{ {
if (Settings.PathExtensions == null) return; if (Settings.PathExtensions == null)
return;
var toInsert = GetPathExtensionsAsString(); var toInsert = GetPathExtensionsAsString();
// Append the original path, if any // Append the original path, if any
if (originalEnvPath != null) if (originalEnvPath != null)
@ -364,21 +399,23 @@ public class SettingsManager : ISettingsManager
package.Version = newVersion; package.Version = newVersion;
SaveSettings(); SaveSettings();
} }
public void SetLastUpdateCheck(InstalledPackage package) public void SetLastUpdateCheck(InstalledPackage package)
{ {
var installedPackage = Settings.InstalledPackages.First(p => p.DisplayName == package.DisplayName); var installedPackage = Settings.InstalledPackages.First(
p => p.DisplayName == package.DisplayName
);
installedPackage.LastUpdateCheck = package.LastUpdateCheck; installedPackage.LastUpdateCheck = package.LastUpdateCheck;
installedPackage.UpdateAvailable = package.UpdateAvailable; installedPackage.UpdateAvailable = package.UpdateAvailable;
SaveSettings(); SaveSettings();
} }
public List<LaunchOption> GetLaunchArgs(Guid packageId) public List<LaunchOption> GetLaunchArgs(Guid packageId)
{ {
var packageData = Settings.InstalledPackages.FirstOrDefault(x => x.Id == packageId); var packageData = Settings.InstalledPackages.FirstOrDefault(x => x.Id == packageId);
return packageData?.LaunchArgs ?? new(); return packageData?.LaunchArgs ?? new();
} }
public void SaveLaunchArgs(Guid packageId, List<LaunchOption> launchArgs) public void SaveLaunchArgs(Guid packageId, List<LaunchOption> launchArgs)
{ {
var packageData = Settings.InstalledPackages.FirstOrDefault(x => x.Id == packageId); var packageData = Settings.InstalledPackages.FirstOrDefault(x => x.Id == packageId);
@ -392,12 +429,17 @@ public class SettingsManager : ISettingsManager
packageData.LaunchArgs = toSave; packageData.LaunchArgs = toSave;
SaveSettings(); SaveSettings();
} }
public string? GetActivePackageHost() public string? GetActivePackageHost()
{ {
var package = Settings.InstalledPackages.FirstOrDefault(x => x.Id == Settings.ActiveInstalledPackageId); var package = Settings.InstalledPackages.FirstOrDefault(
if (package == null) return null; x => x.Id == Settings.ActiveInstalledPackageId
var hostOption = package.LaunchArgs?.FirstOrDefault(x => x.Name.ToLowerInvariant() == "host"); );
if (package == null)
return null;
var hostOption = package.LaunchArgs?.FirstOrDefault(
x => x.Name.ToLowerInvariant() == "host"
);
if (hostOption?.OptionValue != null) if (hostOption?.OptionValue != null)
{ {
return hostOption.OptionValue as string; return hostOption.OptionValue as string;
@ -407,9 +449,14 @@ public class SettingsManager : ISettingsManager
public string? GetActivePackagePort() public string? GetActivePackagePort()
{ {
var package = Settings.InstalledPackages.FirstOrDefault(x => x.Id == Settings.ActiveInstalledPackageId); var package = Settings.InstalledPackages.FirstOrDefault(
if (package == null) return null; x => x.Id == Settings.ActiveInstalledPackageId
var portOption = package.LaunchArgs?.FirstOrDefault(x => x.Name.ToLowerInvariant() == "port"); );
if (package == null)
return null;
var portOption = package.LaunchArgs?.FirstOrDefault(
x => x.Name.ToLowerInvariant() == "port"
);
if (portOption?.OptionValue != null) if (portOption?.OptionValue != null)
{ {
return portOption.OptionValue as string; return portOption.OptionValue as string;
@ -430,11 +477,12 @@ public class SettingsManager : ISettingsManager
} }
SaveSettings(); SaveSettings();
} }
public bool IsSharedFolderCategoryVisible(SharedFolderType type) public bool IsSharedFolderCategoryVisible(SharedFolderType type)
{ {
// False for default // False for default
if (type == 0) return false; if (type == 0)
return false;
return Settings.SharedFolderVisibleCategories?.HasFlag(type) ?? false; return Settings.SharedFolderVisibleCategories?.HasFlag(type) ?? false;
} }
@ -455,7 +503,7 @@ public class SettingsManager : ISettingsManager
public void SetEulaAccepted() public void SetEulaAccepted()
{ {
var globalSettings = new GlobalSettings {EulaAccepted = true}; var globalSettings = new GlobalSettings { EulaAccepted = true };
var json = JsonSerializer.Serialize(globalSettings); var json = JsonSerializer.Serialize(globalSettings);
File.WriteAllText(GlobalSettingsPath, json); File.WriteAllText(GlobalSettingsPath, json);
} }
@ -471,11 +519,15 @@ public class SettingsManager : ISettingsManager
var modelHashes = new HashSet<string>(); var modelHashes = new HashSet<string>();
var sharedModelDirectory = Path.Combine(LibraryDir, "Models"); var sharedModelDirectory = Path.Combine(LibraryDir, "Models");
if (!Directory.Exists(sharedModelDirectory)) return; if (!Directory.Exists(sharedModelDirectory))
return;
var connectedModelJsons = Directory.GetFiles(sharedModelDirectory, "*.cm-info.json",
SearchOption.AllDirectories); var connectedModelJsons = Directory.GetFiles(
sharedModelDirectory,
"*.cm-info.json",
SearchOption.AllDirectories
);
foreach (var jsonFile in connectedModelJsons) foreach (var jsonFile in connectedModelJsons)
{ {
var json = File.ReadAllText(jsonFile); var json = File.ReadAllText(jsonFile);
@ -488,7 +540,7 @@ public class SettingsManager : ISettingsManager
} }
Transaction(s => s.InstalledModelHashes = modelHashes); Transaction(s => s.InstalledModelHashes = modelHashes);
sw.Stop(); sw.Stop();
Logger.Info($"Indexed {modelHashes.Count} checkpoints in {sw.ElapsedMilliseconds}ms"); Logger.Info($"Indexed {modelHashes.Count} checkpoints in {sw.ElapsedMilliseconds}ms");
} }
@ -515,9 +567,10 @@ public class SettingsManager : ISettingsManager
var modifiedDefaultSerializerOptions = var modifiedDefaultSerializerOptions =
SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions(); SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions();
modifiedDefaultSerializerOptions.Converters.Add(new JsonStringEnumConverter()); modifiedDefaultSerializerOptions.Converters.Add(new JsonStringEnumConverter());
Settings = Settings = JsonSerializer.Deserialize<Settings>(
JsonSerializer.Deserialize<Settings>(settingsContent, settingsContent,
modifiedDefaultSerializerOptions)!; modifiedDefaultSerializerOptions
)!;
} }
finally finally
{ {
@ -534,12 +587,15 @@ public class SettingsManager : ISettingsManager
{ {
File.Create(SettingsPath).Close(); File.Create(SettingsPath).Close();
} }
var json = JsonSerializer.Serialize(Settings, new JsonSerializerOptions var json = JsonSerializer.Serialize(
{ Settings,
WriteIndented = true, new JsonSerializerOptions
Converters = { new JsonStringEnumConverter() } {
}); WriteIndented = true,
Converters = { new JsonStringEnumConverter() }
}
);
File.WriteAllText(SettingsPath, json); File.WriteAllText(SettingsPath, json);
} }
finally finally
@ -553,4 +609,3 @@ public class SettingsManager : ISettingsManager
return Task.Run(SaveSettings); return Task.Run(SaveSettings);
} }
} }

98
StabilityMatrix.Core/Services/TrackedDownloadService.cs

@ -14,18 +14,22 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
private readonly ILogger<TrackedDownloadService> logger; private readonly ILogger<TrackedDownloadService> logger;
private readonly IDownloadService downloadService; private readonly IDownloadService downloadService;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly ConcurrentDictionary<Guid, (TrackedDownload Download, FileStream Stream)> downloads = new(); private readonly ConcurrentDictionary<
Guid,
(TrackedDownload Download, FileStream Stream)
> downloads = new();
public IEnumerable<TrackedDownload> Downloads => downloads.Values.Select(x => x.Download); public IEnumerable<TrackedDownload> Downloads => downloads.Values.Select(x => x.Download);
/// <inheritdoc /> /// <inheritdoc />
public event EventHandler<TrackedDownload>? DownloadAdded; public event EventHandler<TrackedDownload>? DownloadAdded;
public TrackedDownloadService( public TrackedDownloadService(
ILogger<TrackedDownloadService> logger, ILogger<TrackedDownloadService> logger,
IDownloadService downloadService, IDownloadService downloadService,
ISettingsManager settingsManager) ISettingsManager settingsManager
)
{ {
this.logger = logger; this.logger = logger;
this.downloadService = downloadService; this.downloadService = downloadService;
@ -36,12 +40,13 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory); var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory);
// Ignore if not exist // Ignore if not exist
if (!downloadsDir.Exists) return; if (!downloadsDir.Exists)
return;
LoadInProgressDownloads(downloadsDir); LoadInProgressDownloads(downloadsDir);
}); });
} }
private void OnDownloadAdded(TrackedDownload download) private void OnDownloadAdded(TrackedDownload download)
{ {
DownloadAdded?.Invoke(this, download); DownloadAdded?.Invoke(this, download);
@ -55,28 +60,32 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
// Set download service // Set download service
download.SetDownloadService(downloadService); download.SetDownloadService(downloadService);
// Create json file // Create json file
var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory); var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory);
downloadsDir.Create(); downloadsDir.Create();
var jsonFile = downloadsDir.JoinFile($"{download.Id}.json"); var jsonFile = downloadsDir.JoinFile($"{download.Id}.json");
var jsonFileStream = jsonFile.Info.Open(FileMode.CreateNew, FileAccess.ReadWrite, FileShare.Read); var jsonFileStream = jsonFile.Info.Open(
FileMode.CreateNew,
FileAccess.ReadWrite,
FileShare.Read
);
// Serialize to json // Serialize to json
var json = JsonSerializer.Serialize(download); var json = JsonSerializer.Serialize(download);
jsonFileStream.Write(Encoding.UTF8.GetBytes(json)); jsonFileStream.Write(Encoding.UTF8.GetBytes(json));
jsonFileStream.Flush(); jsonFileStream.Flush();
// Add to dictionary // Add to dictionary
downloads.TryAdd(download.Id, (download, jsonFileStream)); downloads.TryAdd(download.Id, (download, jsonFileStream));
// Connect to state changed event to update json file // Connect to state changed event to update json file
AttachHandlers(download); AttachHandlers(download);
logger.LogDebug("Added download {Download}", download.FileName); logger.LogDebug("Added download {Download}", download.FileName);
OnDownloadAdded(download); OnDownloadAdded(download);
} }
/// <summary> /// <summary>
/// Update the json file for the download. /// Update the json file for the download.
/// </summary> /// </summary>
@ -85,19 +94,19 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
// Serialize to json // Serialize to json
var json = JsonSerializer.Serialize(download); var json = JsonSerializer.Serialize(download);
var jsonBytes = Encoding.UTF8.GetBytes(json); var jsonBytes = Encoding.UTF8.GetBytes(json);
// Write to file // Write to file
var (_, fs) = downloads[download.Id]; var (_, fs) = downloads[download.Id];
fs.Seek(0, SeekOrigin.Begin); fs.Seek(0, SeekOrigin.Begin);
fs.Write(jsonBytes); fs.Write(jsonBytes);
fs.Flush(); fs.Flush();
} }
private void AttachHandlers(TrackedDownload download) private void AttachHandlers(TrackedDownload download)
{ {
download.ProgressStateChanged += TrackedDownload_OnProgressStateChanged; download.ProgressStateChanged += TrackedDownload_OnProgressStateChanged;
} }
/// <summary> /// <summary>
/// Handler when the download's state changes /// Handler when the download's state changes
/// </summary> /// </summary>
@ -107,10 +116,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
return; return;
} }
// Update json file // Update json file
UpdateJsonForDownload(download); UpdateJsonForDownload(download);
// If the download is completed, remove it from the dictionary and delete the json file // If the download is completed, remove it from the dictionary and delete the json file
if (e is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled) if (e is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled)
{ {
@ -118,11 +127,13 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
downloadInfo.Item2.Dispose(); downloadInfo.Item2.Dispose();
// Delete json file // Delete json file
new DirectoryPath(settingsManager.DownloadsDirectory).JoinFile($"{download.Id}.json").Delete(); new DirectoryPath(settingsManager.DownloadsDirectory)
.JoinFile($"{download.Id}.json")
.Delete();
logger.LogDebug("Removed download {Download}", download.FileName); logger.LogDebug("Removed download {Download}", download.FileName);
} }
} }
// On successes, run the continuation action // On successes, run the continuation action
if (e == ProgressState.Success) if (e == ProgressState.Success)
{ {
@ -133,13 +144,13 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
} }
} }
} }
private void LoadInProgressDownloads(DirectoryPath downloadsDir) private void LoadInProgressDownloads(DirectoryPath downloadsDir)
{ {
logger.LogDebug("Indexing in-progress downloads at {DownloadsDir}...", downloadsDir); logger.LogDebug("Indexing in-progress downloads at {DownloadsDir}...", downloadsDir);
var jsonFiles = downloadsDir.Info.EnumerateFiles("*.json", SearchOption.TopDirectoryOnly); var jsonFiles = downloadsDir.Info.EnumerateFiles("*.json", SearchOption.TopDirectoryOnly);
// Add to dictionary, the file name is the guid // Add to dictionary, the file name is the guid
foreach (var file in jsonFiles) foreach (var file in jsonFiles)
{ {
@ -147,10 +158,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
try try
{ {
var fileStream = file.Open(FileMode.Open, FileAccess.ReadWrite, FileShare.Read); var fileStream = file.Open(FileMode.Open, FileAccess.ReadWrite, FileShare.Read);
// Deserialize json and add to dictionary // Deserialize json and add to dictionary
var download = JsonSerializer.Deserialize<TrackedDownload>(fileStream)!; var download = JsonSerializer.Deserialize<TrackedDownload>(fileStream)!;
// If the download is marked as working, pause it // If the download is marked as working, pause it
if (download.ProgressState == ProgressState.Working) if (download.ProgressState == ProgressState.Working)
{ {
@ -159,23 +170,30 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
else if (download.ProgressState != ProgressState.Inactive) else if (download.ProgressState != ProgressState.Inactive)
{ {
// If the download is not inactive, skip it // If the download is not inactive, skip it
logger.LogWarning("Skipping download {Download} with state {State}", download.FileName, download.ProgressState); logger.LogWarning(
"Skipping download {Download} with state {State}",
download.FileName,
download.ProgressState
);
fileStream.Dispose(); fileStream.Dispose();
// Delete json file // Delete json file
logger.LogDebug("Deleting json file for {Download} with unsupported state", download.FileName); logger.LogDebug(
"Deleting json file for {Download} with unsupported state",
download.FileName
);
file.Delete(); file.Delete();
continue; continue;
} }
download.SetDownloadService(downloadService); download.SetDownloadService(downloadService);
downloads.TryAdd(download.Id, (download, fileStream)); downloads.TryAdd(download.Id, (download, fileStream));
AttachHandlers(download); AttachHandlers(download);
OnDownloadAdded(download); OnDownloadAdded(download);
logger.LogDebug("Loaded in-progress download {Download}", download.FileName); logger.LogDebug("Loaded in-progress download {Download}", download.FileName);
} }
catch (Exception e) catch (Exception e)
@ -197,10 +215,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
}; };
AddDownload(download); AddDownload(download);
return download; return download;
} }
/// <summary> /// <summary>
/// Generate a new temp file name that is unique in the given directory. /// Generate a new temp file name that is unique in the given directory.
/// In format of "Unconfirmed {id}.smdownload" /// In format of "Unconfirmed {id}.smdownload"
@ -213,14 +231,14 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
for (var i = 0; i < 10; i++) for (var i = 0; i < 10; i++)
{ {
if (tempFile is {Exists: false}) if (tempFile is { Exists: false })
{ {
return tempFile.Name; return tempFile.Name;
} }
var id = Random.Shared.Next(1000000, 9999999); var id = Random.Shared.Next(1000000, 9999999);
tempFile = parentDir.JoinFile($"Unconfirmed {id}.smdownload"); tempFile = parentDir.JoinFile($"Unconfirmed {id}.smdownload");
} }
throw new Exception("Failed to generate a unique temp file name."); throw new Exception("Failed to generate a unique temp file name.");
} }
@ -241,7 +259,7 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
} }
} }
} }
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
} }

Loading…
Cancel
Save