Browse Source

Merge pull request #315 from ionite34/plugin-management

pull/240/head
Ionite 1 year ago committed by GitHub
parent
commit
3280aa0637
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      CHANGELOG.md
  2. 6
      StabilityMatrix.Avalonia/App.axaml.cs
  3. 11
      StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml
  4. 2
      StabilityMatrix.Avalonia/Controls/ModelCard.axaml
  5. 3
      StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs
  6. 19
      StabilityMatrix.Avalonia/DesignData/DesignData.cs
  7. 12
      StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs
  8. 18
      StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs
  9. 50
      StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs
  10. 12
      StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs
  11. 62
      StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs
  12. 47
      StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs
  13. 26
      StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs
  14. 22
      StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs
  15. 2
      StabilityMatrix.Avalonia/DialogHelper.cs
  16. 2
      StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs
  17. 35
      StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs
  18. 15
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs
  19. 18
      StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs
  20. 3
      StabilityMatrix.Avalonia/Models/TextEditorPreset.cs
  21. 26
      StabilityMatrix.Avalonia/Services/INotificationService.cs
  22. 21
      StabilityMatrix.Avalonia/Services/NotificationService.cs
  23. 2
      StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj
  24. 2
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  25. 10
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  26. 29
      StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs
  27. 25
      StabilityMatrix.Avalonia/Views/MainWindow.axaml.cs
  28. 6
      StabilityMatrix.Core/Attributes/SingletonAttribute.cs
  29. 6
      StabilityMatrix.Core/Attributes/TransientAttribute.cs
  30. 9
      StabilityMatrix.Core/Exceptions/AppException.cs
  31. 23
      StabilityMatrix.Core/Extensions/ProgressExtensions.cs
  32. 47
      StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs
  33. 13
      StabilityMatrix.Core/Models/HybridModelFile.cs
  34. 42
      StabilityMatrix.Core/Models/PackageModification/PipStep.cs
  35. 8
      StabilityMatrix.Core/Models/Packages/A3WebUI.cs
  36. 16
      StabilityMatrix.Core/Models/Packages/BasePackage.cs
  37. 32
      StabilityMatrix.Core/Models/Packages/ComfyUI.cs
  38. 8
      StabilityMatrix.Core/Models/Packages/InvokeAI.cs
  39. 8
      StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs
  40. 6
      StabilityMatrix.Core/Processes/Argument.cs
  41. 72
      StabilityMatrix.Core/Processes/ProcessArgs.cs
  42. 75
      StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs
  43. 31
      StabilityMatrix.Core/Python/PipInstallArgs.cs
  44. 33
      StabilityMatrix.Core/Python/PyVenvRunner.cs
  45. 2
      StabilityMatrix.Core/StabilityMatrix.Core.csproj
  46. 61
      StabilityMatrix.Tests/Core/PipInstallArgsTests.cs
  47. 43
      StabilityMatrix.Tests/Models/ProcessArgsTests.cs

1
CHANGELOG.md

@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
- Update will occur the next time the package is updated, or on a fresh install
- Note: CUDA 12.1 is only available on Maxwell (GTX 900 series) and newer GPUs
- Improved Model Browser download stability with automatic retries for download errors
- Optimized page navigation and syntax formatting configurations to improve startup time
### Fixed
- Fixed crash when clicking Inference gallery image after the image is deleted externally in file explorer
- Fixed Inference popup Install button not working on One-Click Installer

6
StabilityMatrix.Avalonia/App.axaml.cs

@ -383,11 +383,7 @@ public sealed class App : Application
services.AddSingleton<IPrerequisiteHelper, UnixPrerequisiteHelper>();
}
if (Design.IsDesignMode)
{
services.AddSingleton<ILiteDbContext, MockLiteDbContext>();
}
else
if (!Design.IsDesignMode)
{
services.AddSingleton<ILiteDbContext, LiteDbContext>();
services.AddSingleton<IDisposable>(p => p.GetRequiredService<ILiteDbContext>());

11
StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml

@ -13,11 +13,6 @@
</Grid>
</Design.PreviewWith>
<Style Selector="ListBox /template/ VirtualizingStackPanel">
<Setter Property="Orientation" Value="Horizontal" />
</Style>
<Style Selector="controls|ImageGalleryCard">
<!-- Set Defaults -->
<Setter Property="Template">
@ -31,6 +26,12 @@
<Grid RowDefinitions="*,Auto">
<Grid.Styles>
<Style Selector="ListBox /template/ VirtualizingStackPanel">
<Setter Property="Orientation" Value="Horizontal" />
</Style>
</Grid.Styles>
<!-- Main image view -->
<Border
Classes="theme-dark"

2
StabilityMatrix.Avalonia/Controls/ModelCard.axaml

@ -81,7 +81,7 @@
HorizontalAlignment="Left"
FontSize="13"
Foreground="{DynamicResource TextFillColorTertiaryBrush}"
Text="{Binding FileName}"
Text="{Binding RelativePath}"
TextWrapping="Wrap" />
</Grid>
</StackPanel>

3
StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs

@ -6,6 +6,7 @@ using AvaloniaEdit;
using AvaloniaEdit.Editing;
using AvaloniaEdit.Utils;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Core.Attributes;
namespace StabilityMatrix.Avalonia.Controls;
@ -33,7 +34,7 @@ public class PromptCard : TemplatedControl
{
if (editor is not null)
{
TextEditorConfigs.ConfigForPrompt(editor);
TextEditorConfigs.Configure(editor, TextEditorPreset.Prompt);
editor.TextArea.Margin = new Thickness(0, 0, 4, 0);
if (editor.TextArea.ActiveInputHandler is TextAreaInputHandler inputHandler)

19
StabilityMatrix.Avalonia/DesignData/DesignData.cs

@ -9,6 +9,8 @@ using AvaloniaEdit.Utils;
using DynamicData;
using DynamicData.Binding;
using Microsoft.Extensions.DependencyInjection;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.TagCompletion;
@ -109,17 +111,18 @@ public static class DesignData
// Mock services
services
.AddSingleton<INotificationService, MockNotificationService>()
.AddSingleton<ISharedFolders, MockSharedFolders>()
.AddSingleton<IDownloadService, MockDownloadService>()
.AddSingleton<IHttpClientFactory, MockHttpClientFactory>()
.AddSingleton<IApiFactory, MockApiFactory>()
.AddSingleton(Substitute.For<INotificationService>())
.AddSingleton(Substitute.For<ISharedFolders>())
.AddSingleton(Substitute.For<IDownloadService>())
.AddSingleton(Substitute.For<IHttpClientFactory>())
.AddSingleton(Substitute.For<IApiFactory>())
.AddSingleton(Substitute.For<IDiscordRichPresenceService>())
.AddSingleton(Substitute.For<ITrackedDownloadService>())
.AddSingleton(Substitute.For<ILiteDbContext>())
.AddSingleton<IInferenceClientManager, MockInferenceClientManager>()
.AddSingleton<IDiscordRichPresenceService, MockDiscordRichPresenceService>()
.AddSingleton<ICompletionProvider, MockCompletionProvider>()
.AddSingleton<IModelIndexService, MockModelIndexService>()
.AddSingleton<IImageIndexService, MockImageIndexService>()
.AddSingleton<ITrackedDownloadService, MockTrackedDownloadService>();
.AddSingleton<IImageIndexService, MockImageIndexService>();
// Placeholder services that nobody should need during design time
services

12
StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs

@ -1,12 +0,0 @@
using System;
using StabilityMatrix.Core.Api;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockApiFactory : IApiFactory
{
public T CreateRefitClient<T>(Uri baseAddress)
{
throw new NotImplementedException();
}
}

18
StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs

@ -1,18 +0,0 @@
using System;
using StabilityMatrix.Avalonia.Services;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockDiscordRichPresenceService : IDiscordRichPresenceService
{
/// <inheritdoc />
public void Dispose()
{
GC.SuppressFinalize(this);
}
/// <inheritdoc />
public void UpdateState()
{
}
}

50
StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs

@ -1,50 +0,0 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockDownloadService : IDownloadService
{
public Task DownloadToFileAsync(
string downloadUrl,
string downloadPath,
IProgress<ProgressReport>? progress = null,
string? httpClientName = null,
CancellationToken cancellationToken = default
)
{
return Task.CompletedTask;
}
/// <inheritdoc />
public Task ResumeDownloadToFileAsync(
string downloadUrl,
string downloadPath,
long existingFileSize,
IProgress<ProgressReport>? progress = null,
string? httpClientName = null,
CancellationToken cancellationToken = default
)
{
return Task.CompletedTask;
}
/// <inheritdoc />
public Task<long> GetFileSizeAsync(
string downloadUrl,
string? httpClientName = null,
CancellationToken cancellationToken = default
)
{
return Task.FromResult(0L);
}
public Task<Stream?> GetImageStreamFromUrl(string url)
{
return Task.FromResult(new MemoryStream(new byte[24]) as Stream)!;
}
}

12
StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs

@ -1,12 +0,0 @@
using System;
using System.Net.Http;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockHttpClientFactory : IHttpClientFactory
{
public HttpClient CreateClient(string name)
{
throw new NotImplementedException();
}
}

62
StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs

@ -1,62 +0,0 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LiteDB.Async;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockLiteDbContext : ILiteDbContext
{
public LiteDatabaseAsync Database => throw new NotImplementedException();
public ILiteCollectionAsync<CivitModel> CivitModels => throw new NotImplementedException();
public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions =>
throw new NotImplementedException();
public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache =>
throw new NotImplementedException();
public ILiteCollectionAsync<LocalModelFile> LocalModelFiles =>
throw new NotImplementedException();
public ILiteCollectionAsync<InferenceProjectEntry> InferenceProjects =>
throw new NotImplementedException();
public ILiteCollectionAsync<LocalImageFile> LocalImageFiles =>
throw new NotImplementedException();
public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(
string hashBlake3
)
{
return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null));
}
public Task<bool> UpsertCivitModelAsync(CivitModel civitModel)
{
return Task.FromResult(true);
}
public Task<bool> UpsertCivitModelAsync(IEnumerable<CivitModel> civitModels)
{
return Task.FromResult(true);
}
public Task<bool> UpsertCivitModelQueryCacheEntryAsync(CivitModelQueryCacheEntry entry)
{
return Task.FromResult(true);
}
public Task<GithubCacheEntry?> GetGithubCacheEntry(string cacheKey)
{
return Task.FromResult<GithubCacheEntry?>(null);
}
public Task<bool> UpsertGithubCacheEntry(GithubCacheEntry cacheEntry)
{
return Task.FromResult(true);
}
public void Dispose()
{
GC.SuppressFinalize(this);
}
}

47
StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs

@ -1,47 +0,0 @@
using System;
using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls.Notifications;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockNotificationService : INotificationService
{
public void Initialize(Visual? visual,
NotificationPosition position = NotificationPosition.BottomRight, int maxItems = 3)
{
}
public void Show(INotification notification)
{
}
public Task<TaskResult<T>> TryAsync<T>(Task<T> task, string title = "Error", string? message = null,
NotificationType appearance = NotificationType.Error)
{
return Task.FromResult(new TaskResult<T>(default!));
}
public Task<TaskResult<bool>> TryAsync(Task task, string title = "Error", string? message = null,
NotificationType appearance = NotificationType.Error)
{
return Task.FromResult(new TaskResult<bool>(true));
}
public void Show(
string title,
string message,
NotificationType appearance = NotificationType.Information,
TimeSpan? expiration = null)
{
}
public void ShowPersistent(
string title,
string message,
NotificationType appearance = NotificationType.Information)
{
}
}

26
StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs

@ -1,26 +0,0 @@
using System.Threading.Tasks;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Packages;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockSharedFolders : ISharedFolders
{
public void SetupLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory)
{
}
public Task UpdateLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory)
{
return Task.CompletedTask;
}
public void RemoveLinksForAllPackages()
{
}
public void SetupSharedModelFolders()
{
}
}

22
StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs

@ -1,22 +0,0 @@
using System;
using System.Collections.Generic;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockTrackedDownloadService : ITrackedDownloadService
{
/// <inheritdoc />
public IEnumerable<TrackedDownload> Downloads => Array.Empty<TrackedDownload>();
/// <inheritdoc />
public event EventHandler<TrackedDownload>? DownloadAdded;
/// <inheritdoc />
public TrackedDownload NewDownload(Uri downloadUrl, FilePath downloadPath)
{
throw new NotImplementedException();
}
}

2
StabilityMatrix.Avalonia/DialogHelper.cs

@ -424,7 +424,7 @@ public static class DialogHelper
AllowScrollBelowDocument = false
}
};
TextEditorConfigs.ConfigForPrompt(textEditor);
TextEditorConfigs.Configure(textEditor, TextEditorPreset.Prompt);
textEditor.Document.Text = errorLineFormatted;
textEditor.TextArea.Caret.Offset = textEditor.Document.Lines[0].EndOffset;

2
StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs

@ -184,7 +184,7 @@ public static class ComfyNodeBuilderExtensions
var checkpointLoader = builder.Nodes.AddNamedNode(
ComfyNodeBuilder.CheckpointLoaderSimple(
"Refiner_CheckpointLoader",
modelCardViewModel.SelectedRefiner?.FileName
modelCardViewModel.SelectedRefiner?.RelativePath
?? throw new NullReferenceException("Model not selected")
)
);

35
StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs

@ -1,4 +1,5 @@
using System.IO;
using System;
using System.IO;
using Avalonia.Media;
using AvaloniaEdit;
using AvaloniaEdit.TextMate;
@ -17,13 +18,22 @@ public static class TextEditorConfigs
{
public static void Configure(TextEditor editor, TextEditorPreset preset)
{
if (preset == TextEditorPreset.Prompt)
switch (preset)
{
case TextEditorPreset.Prompt:
ConfigForPrompt(editor);
break;
case TextEditorPreset.Console:
ConfigForConsole(editor);
break;
case TextEditorPreset.None:
break;
default:
throw new ArgumentOutOfRangeException(nameof(preset), preset, null);
}
}
public static void ConfigForPrompt(TextEditor editor)
private static void ConfigForPrompt(TextEditor editor)
{
const ThemeName themeName = ThemeName.DimmedMonokai;
var registryOptions = new RegistryOptions(themeName);
@ -57,6 +67,25 @@ public static class TextEditorConfigs
installation.SetTheme(theme);
}
private static void ConfigForConsole(TextEditor editor)
{
var registryOptions = new RegistryOptions(ThemeName.DarkPlus);
// Config hyperlinks
editor.TextArea.Options.EnableHyperlinks = true;
editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false;
editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral;
var textMate = editor.InstallTextMate(registryOptions);
var scope = registryOptions.GetScopeByLanguageId("log");
if (scope is null)
throw new InvalidOperationException("Scope is null");
textMate.SetGrammar(scope);
textMate.SetTheme(registryOptions.LoadTheme(ThemeName.DarkPlus));
}
private static IRawTheme GetThemeFromStream(Stream stream)
{
using var reader = new StreamReader(stream);

15
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs

@ -1,16 +1,7 @@
using System;
using System.Runtime.InteropServices;
using CSharpDiscriminatedUnion.Attributes;
using OneOf;
namespace StabilityMatrix.Avalonia.Models.Inference;
[GenerateDiscriminatedUnion(CaseFactoryPrefix = "From")]
[StructLayout(LayoutKind.Auto)]
public readonly partial struct FileNameFormatPart
{
[StructCase("Constant", isDefaultValue: true)]
private readonly string constant;
[StructCase("Substitution")]
private readonly Func<string?> substitution;
}
[GenerateOneOf]
public partial class FileNameFormatPart : OneOfBase<string, Func<string?>> { }

18
StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs

@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.Contracts;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using Avalonia.Data;
@ -26,7 +27,10 @@ public partial class FileNameFormatProvider
{ "seed", () => GenerationParameters?.Seed.ToString() },
{ "prompt", () => GenerationParameters?.PositivePrompt },
{ "negative_prompt", () => GenerationParameters?.NegativePrompt },
{ "model_name", () => GenerationParameters?.ModelName },
{
"model_name",
() => Path.GetFileNameWithoutExtension(GenerationParameters?.ModelName)
},
{ "model_hash", () => GenerationParameters?.ModelHash },
{ "width", () => GenerationParameters?.Width.ToString() },
{ "height", () => GenerationParameters?.Height.ToString() },
@ -84,7 +88,7 @@ public partial class FileNameFormatProvider
if (result.Index != currentIndex)
{
var constant = template[currentIndex..result.Index];
parts.Add(FileNameFormatPart.FromConstant(constant));
parts.Add(constant);
currentIndex += constant.Length;
}
@ -97,7 +101,8 @@ public partial class FileNameFormatProvider
if (slice is not null)
{
parts.Add(
FileNameFormatPart.FromSubstitution(() =>
(FileNameFormatPart)(
() =>
{
var value = substitution();
if (value is null)
@ -115,12 +120,13 @@ public partial class FileNameFormatProvider
}
return value;
})
}
)
);
}
else
{
parts.Add(FileNameFormatPart.FromSubstitution(substitution));
parts.Add(substitution);
}
currentIndex += result.Length;
@ -130,7 +136,7 @@ public partial class FileNameFormatProvider
if (currentIndex != template.Length)
{
var constant = template[currentIndex..];
parts.Add(FileNameFormatPart.FromConstant(constant));
parts.Add(constant);
}
return parts;

3
StabilityMatrix.Avalonia/Models/TextEditorPreset.cs

@ -3,5 +3,6 @@
public enum TextEditorPreset
{
None,
Prompt
Prompt,
Console
}

26
StabilityMatrix.Avalonia/Services/INotificationService.cs

@ -2,6 +2,8 @@
using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls.Notifications;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Services;
@ -11,7 +13,8 @@ public interface INotificationService
public void Initialize(
Visual? visual,
NotificationPosition position = NotificationPosition.BottomRight,
int maxItems = 3);
int maxItems = 3
);
public void Show(INotification notification);
@ -26,7 +29,8 @@ public interface INotificationService
Task<T> task,
string title = "Error",
string? message = null,
NotificationType appearance = NotificationType.Error);
NotificationType appearance = NotificationType.Error
);
/// <summary>
/// Attempt to run the given void task, showing a generic error notification if it fails.
@ -40,7 +44,8 @@ public interface INotificationService
Task task,
string title = "Error",
string? message = null,
NotificationType appearance = NotificationType.Error);
NotificationType appearance = NotificationType.Error
);
/// <summary>
/// Show a notification with the given parameters.
@ -49,7 +54,8 @@ public interface INotificationService
string title,
string message,
NotificationType appearance = NotificationType.Information,
TimeSpan? expiration = null);
TimeSpan? expiration = null
);
/// <summary>
/// Show a notification that will not auto-dismiss.
@ -60,5 +66,15 @@ public interface INotificationService
void ShowPersistent(
string title,
string message,
NotificationType appearance = NotificationType.Information);
NotificationType appearance = NotificationType.Information
);
/// <summary>
/// Show a notification for a <see cref="AppException"/> that will not auto-dismiss.
/// </summary>
void ShowPersistent(
AppException exception,
NotificationType appearance = NotificationType.Error,
LogLevel logLevel = LogLevel.Warning
);
}

21
StabilityMatrix.Avalonia/Services/NotificationService.cs

@ -3,7 +3,9 @@ using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Services;
@ -11,8 +13,14 @@ namespace StabilityMatrix.Avalonia.Services;
[Singleton(typeof(INotificationService))]
public class NotificationService : INotificationService
{
private readonly ILogger<NotificationService> logger;
private WindowNotificationManager? notificationManager;
public NotificationService(ILogger<NotificationService> logger)
{
this.logger = logger;
}
public void Initialize(
Visual? visual,
NotificationPosition position = NotificationPosition.BottomRight,
@ -52,6 +60,19 @@ public class NotificationService : INotificationService
Show(new Notification(title, message, appearance, TimeSpan.Zero));
}
/// <inheritdoc />
public void ShowPersistent(
AppException exception,
NotificationType appearance = NotificationType.Warning,
LogLevel logLevel = LogLevel.Warning
)
{
// Log exception
logger.Log(logLevel, exception, "{Message}", exception.Message);
Show(new Notification(exception.Message, exception.Details, appearance, TimeSpan.Zero));
}
/// <inheritdoc />
public async Task<TaskResult<T>> TryAsync<T>(
Task<T> task,

2
StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

@ -32,7 +32,6 @@
<PackageReference Include="Avalonia.Xaml.Behaviors" Version="11.0.2" />
<PackageReference Include="AvaloniaEdit.TextMate" Version="11.0.1" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="8.2.1" />
<PackageReference Include="CSharpDiscriminatedUnion" Version="2.0.1" />
<PackageReference Include="DiscordRichPresence" Version="1.2.1.24" />
<PackageReference Include="Dock.Avalonia" Version="11.0.0.2" />
<PackageReference Include="Dock.Model.Avalonia" Version="11.0.0.2" />
@ -53,6 +52,7 @@
<PackageReference Include="Nito.AsyncEx" Version="5.1.2" />
<PackageReference Include="NLog" Version="5.2.5" />
<PackageReference Include="NLog.Extensions.Logging" Version="5.3.5" />
<PackageReference Include="NSubstitute" Version="5.1.0" />
<PackageReference Include="Polly" Version="8.0.0" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="Polly.Extensions.Http" Version="3.0.0" />

2
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -251,7 +251,7 @@ public class InferenceTextToImageViewModel
if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false })
{
var customVaeLoader = nodes.AddNamedNode(
ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.FileName)
ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath)
);
builder.Connections.BaseVAE = customVaeLoader.Output;

10
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -31,9 +31,9 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
[ObservableProperty]
private bool isVaeSelectionEnabled;
public string? SelectedModelName => SelectedModel?.FileName;
public string? SelectedModelName => SelectedModel?.RelativePath;
public string? SelectedVaeName => SelectedVae?.FileName;
public string? SelectedVaeName => SelectedVae?.RelativePath;
public IInferenceClientManager ClientManager { get; }
@ -62,11 +62,11 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
SelectedModel = model.SelectedModelName is null
? null
: ClientManager.Models.FirstOrDefault(x => x.FileName == model.SelectedModelName);
: ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName);
SelectedVae = model.SelectedVaeName is null
? HybridModelFile.Default
: ClientManager.VaeModels.FirstOrDefault(x => x.FileName == model.SelectedVaeName);
: ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName);
}
internal class ModelCardModel
@ -101,7 +101,7 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad
else
{
// Name matches
model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName));
model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName));
model ??= currentModels.FirstOrDefault(
m => m.ShortDisplayName.StartsWith(paramsModelName)
);

29
StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs

@ -1,14 +1,14 @@
using System;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using Avalonia.Interactivity;
using Avalonia.Media;
using Avalonia.Threading;
using AvaloniaEdit;
using AvaloniaEdit.TextMate;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using TextMateSharp.Grammars;
namespace StabilityMatrix.Avalonia.Views;
@ -20,25 +20,14 @@ public partial class LaunchPageView : UserControlBase
public LaunchPageView()
{
InitializeComponent();
var editor = this.FindControl<TextEditor>("Console");
if (editor is not null)
{
var options = new RegistryOptions(ThemeName.DarkPlus);
// Config hyperlinks
editor.TextArea.Options.EnableHyperlinks = true;
editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false;
editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral;
var textMate = editor.InstallTextMate(options);
var scope = options.GetScopeByLanguageId("log");
}
if (scope is null)
throw new InvalidOperationException("Scope is null");
/// <inheritdoc />
protected override void OnApplyTemplate(TemplateAppliedEventArgs e)
{
base.OnApplyTemplate(e);
textMate.SetGrammar(scope);
textMate.SetTheme(options.LoadTheme(ThemeName.DarkPlus));
}
TextEditorConfigs.Configure(Console, TextEditorPreset.Console);
}
protected override void OnUnloaded(RoutedEventArgs e)

25
StabilityMatrix.Avalonia/Views/MainWindow.axaml.cs

@ -87,14 +87,6 @@ public partial class MainWindow : AppWindowBase
navigationService.SetFrame(
FrameView ?? throw new NullReferenceException("Frame not found")
);
// Navigate to first page
if (DataContext is not MainWindowViewModel vm)
{
throw new NullReferenceException("DataContext is not MainWindowViewModel");
}
navigationService.NavigateTo(vm.Pages[0], new DrillInNavigationTransitionInfo());
}
protected override void OnOpened(EventArgs e)
@ -133,8 +125,23 @@ public partial class MainWindow : AppWindowBase
loader.LoadFailed += OnImageLoadFailed;
}
if (DataContext is not MainWindowViewModel vm)
return;
// Navigate to first page
Dispatcher.UIThread.Post(
() =>
navigationService.NavigateTo(
vm.Pages[0],
new BetterSlideNavigationTransition
{
Effect = SlideNavigationTransitionEffect.FromBottom
}
)
);
// Check show update teaching tip
if (DataContext is MainWindowViewModel { UpdateViewModel.IsUpdateAvailable: true } vm)
if (vm.UpdateViewModel.IsUpdateAvailable)
{
OnUpdateAvailable(this, vm.UpdateViewModel.UpdateInfo);
}

6
StabilityMatrix.Core/Attributes/SingletonAttribute.cs

@ -1,8 +1,12 @@
namespace StabilityMatrix.Core.Attributes;
using System.Diagnostics.CodeAnalysis;
namespace StabilityMatrix.Core.Attributes;
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
[AttributeUsage(AttributeTargets.Class)]
public class SingletonAttribute : Attribute
{
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
public Type? InterfaceType { get; init; }
public SingletonAttribute() { }

6
StabilityMatrix.Core/Attributes/TransientAttribute.cs

@ -1,8 +1,12 @@
namespace StabilityMatrix.Core.Attributes;
using System.Diagnostics.CodeAnalysis;
namespace StabilityMatrix.Core.Attributes;
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
[AttributeUsage(AttributeTargets.Class)]
public class TransientAttribute : Attribute
{
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
public Type? InterfaceType { get; init; }
public TransientAttribute() { }

9
StabilityMatrix.Core/Exceptions/AppException.cs

@ -0,0 +1,9 @@
namespace StabilityMatrix.Core.Exceptions;
/// <summary>
/// Generic runtime exception with custom handling by notification service
/// </summary>
public class AppException : ApplicationException
{
public string? Details { get; init; }
}

23
StabilityMatrix.Core/Extensions/ProgressExtensions.cs

@ -0,0 +1,23 @@
using System.Diagnostics.CodeAnalysis;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Extensions;
public static class ProgressExtensions
{
[return: NotNullIfNotNull(nameof(progress))]
public static Action<ProcessOutput>? AsProcessOutputHandler(
this IProgress<ProgressReport>? progress
)
{
return progress == null
? null
: output =>
{
progress.Report(
new ProgressReport { IsIndeterminate = true, Message = output.Text }
);
};
}
}

47
StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs

@ -1,4 +1,5 @@
using System.Diagnostics.CodeAnalysis;
using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
using StabilityMatrix.Core.Converters.Json;
@ -6,7 +7,7 @@ namespace StabilityMatrix.Core.Models.FileInterfaces;
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[JsonConverter(typeof(StringJsonConverter<DirectoryPath>))]
public class DirectoryPath : FileSystemPath, IPathObject
public class DirectoryPath : FileSystemPath, IPathObject, IEnumerable<FileSystemPath>
{
private DirectoryInfo? info;
@ -133,8 +134,50 @@ public class DirectoryPath : FileSystemPath, IPathObject
public FilePath JoinFile(params FilePath[] paths) =>
new(Path.Combine(FullPath, Path.Combine(paths.Select(path => path.FullPath).ToArray())));
/// <summary>
/// Returns an enumerable collection of files that matches
/// a specified search pattern and search subdirectory option.
/// </summary>
public IEnumerable<FilePath> EnumerateFiles(
string searchPattern = "*",
SearchOption searchOption = SearchOption.TopDirectoryOnly
) => Info.EnumerateFiles(searchPattern, searchOption).Select(file => new FilePath(file));
/// <summary>
/// Returns an enumerable collection of directories that matches
/// a specified search pattern and search subdirectory option.
/// </summary>
public IEnumerable<DirectoryPath> EnumerateDirectories(
string searchPattern = "*",
SearchOption searchOption = SearchOption.TopDirectoryOnly
) =>
Info.EnumerateDirectories(searchPattern, searchOption)
.Select(directory => new DirectoryPath(directory));
public override string ToString() => FullPath;
/// <inheritdoc />
public IEnumerator<FileSystemPath> GetEnumerator()
{
return Info.EnumerateFileSystemInfos("*", SearchOption.TopDirectoryOnly)
.Select<FileSystemInfo, FileSystemPath>(
fsInfo =>
fsInfo switch
{
FileInfo file => new FilePath(file),
DirectoryInfo directory => new DirectoryPath(directory),
_ => throw new InvalidOperationException("Unknown file system info type")
}
)
.GetEnumerator();
}
/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
// DirectoryPath + DirectoryPath = DirectoryPath
public static DirectoryPath operator +(DirectoryPath path, DirectoryPath other) =>
new(Path.Combine(path, other.FullPath));

13
StabilityMatrix.Core/Models/HybridModelFile.cs

@ -30,7 +30,10 @@ public record HybridModelFile
public bool IsRemote => RemoteName != null;
[JsonIgnore]
public string FileName => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder;
public string RelativePath => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder;
[JsonIgnore]
public string FileName => Path.GetFileName(RelativePath);
[JsonIgnore]
public string ShortDisplayName
@ -47,7 +50,7 @@ public record HybridModelFile
return "Default";
}
return Path.GetFileNameWithoutExtension(FileName);
return Path.GetFileNameWithoutExtension(RelativePath);
}
}
@ -63,7 +66,7 @@ public record HybridModelFile
public string GetId()
{
return $"{FileName};{IsNone};{IsDefault}";
return $"{RelativePath};{IsNone};{IsDefault}";
}
private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer<HybridModelFile>
@ -79,14 +82,14 @@ public record HybridModelFile
if (x.GetType() != y.GetType())
return false;
return Equals(x.FileName, y.FileName)
return Equals(x.RelativePath, y.RelativePath)
&& x.IsNone == y.IsNone
&& x.IsDefault == y.IsDefault;
}
public int GetHashCode(HybridModelFile obj)
{
return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.FileName);
return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.RelativePath);
}
}

42
StabilityMatrix.Core/Models/PackageModification/PipStep.cs

@ -0,0 +1,42 @@
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
namespace StabilityMatrix.Core.Models.PackageModification;
public class PipStep : IPackageStep
{
public required ProcessArgs Args { get; init; }
public required DirectoryPath VenvDirectory { get; init; }
public DirectoryPath? WorkingDirectory { get; init; }
public IReadOnlyDictionary<string, string>? EnvironmentVariables { get; init; }
/// <inheritdoc />
public string ProgressTitle =>
Args switch
{
_ when Args.Contains("install") => "Installing Pip Packages",
_ when Args.Contains("uninstall") => "Uninstalling Pip Packages",
_ when Args.Contains("-U") || Args.Contains("--upgrade") => "Updating Pip Packages",
_ => "Running Pip"
};
/// <inheritdoc />
public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null)
{
await using var venvRunner = new PyVenvRunner(VenvDirectory)
{
WorkingDirectory = WorkingDirectory,
EnvironmentVariables = EnvironmentVariables
};
var args = new List<string> { "-m", "pip" };
args.AddRange(Args.ToArray());
venvRunner.RunDetached(args.ToArray(), progress.AsProcessOutputHandler());
}
}

8
StabilityMatrix.Core/Models/Packages/A3WebUI.cs

@ -281,7 +281,13 @@ public class A3WebUI : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithTorchExtraIndex("rocm5.1.1"),
onConsoleOutput
)
.ConfigureAwait(false);
}
}

16
StabilityMatrix.Core/Models/Packages/BasePackage.cs

@ -189,9 +189,14 @@ public abstract class BasePackage
);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithXFormers("==0.0.20")
.WithTorchExtraIndex("cu118"),
onConsoleOutput
)
.ConfigureAwait(false);
await venvRunner.PipInstall("xformers==0.0.20", onConsoleOutput).ConfigureAwait(false);
}
protected Task InstallDirectMlTorch(
@ -204,7 +209,7 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true)
);
return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput);
return venvRunner.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput);
}
protected Task InstallCpuTorch(
@ -217,6 +222,9 @@ public abstract class BasePackage
new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true)
);
return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, onConsoleOutput);
return venvRunner.PipInstall(
new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(),
onConsoleOutput
);
}
}

32
StabilityMatrix.Core/Models/Packages/ComfyUI.cs

@ -179,15 +179,20 @@ public class ComfyUI : BaseGitPackage
break;
case TorchVersion.Cuda:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda121, onConsoleOutput)
.ConfigureAwait(false);
await venvRunner
.PipInstall("xformers==0.0.22.post4 --upgrade")
.PipInstall(
new PipInstallArgs()
.WithTorch("~=2.1.0")
.WithTorchVision()
.WithXFormers("==0.0.22.post4")
.AddArg("--upgrade")
.WithTorchExtraIndex("cu121"),
onConsoleOutput
)
.ConfigureAwait(false);
break;
case TorchVersion.DirectMl:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput)
.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput)
.ConfigureAwait(false);
break;
case TorchVersion.Rocm:
@ -195,7 +200,14 @@ public class ComfyUI : BaseGitPackage
break;
case TorchVersion.Mps:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsNightlyCpu, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.AddArg("--pre")
.WithTorch()
.WithTorchVision()
.WithTorchExtraIndex("nightly/cpu"),
onConsoleOutput
)
.ConfigureAwait(false);
break;
default:
@ -465,7 +477,13 @@ public class ComfyUI : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm56, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithTorchExtraIndex("rocm5.6"),
onConsoleOutput
)
.ConfigureAwait(false);
}

8
StabilityMatrix.Core/Models/Packages/InvokeAI.cs

@ -182,7 +182,13 @@ public class InvokeAI : BaseGitPackage
// For AMD, Install ROCm version
case TorchVersion.Rocm:
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithExtraIndex("rocm5.4.2"),
onConsoleOutput
)
.ConfigureAwait(false);
Logger.Info("Starting InvokeAI install (ROCm)...");
pipCommandArgs =

8
StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs

@ -263,7 +263,13 @@ public class StableDiffusionUx : BaseGitPackage
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
await venvRunner
.PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput)
.PipInstall(
new PipInstallArgs()
.WithTorch("==2.0.1")
.WithTorchVision()
.WithTorchExtraIndex("rocm5.1.1"),
onConsoleOutput
)
.ConfigureAwait(false);
}
}

6
StabilityMatrix.Core/Processes/Argument.cs

@ -0,0 +1,6 @@
using OneOf;
namespace StabilityMatrix.Core.Processes;
[GenerateOneOf]
public partial class Argument : OneOfBase<string, (string, string)> { }

72
StabilityMatrix.Core/Processes/ProcessArgs.cs

@ -0,0 +1,72 @@
using System.Collections;
using System.Text.RegularExpressions;
using OneOf;
namespace StabilityMatrix.Core.Processes;
/// <summary>
/// Parameter type for command line arguments
/// Implicitly converts between string and string[],
/// with no parsing if the input and output types are the same.
/// </summary>
public partial class ProcessArgs : OneOfBase<string, string[]>, IEnumerable<string>
{
/// <inheritdoc />
public ProcessArgs(OneOf<string, string[]> input)
: base(input) { }
/// <summary>
/// Whether the argument string contains the given substring,
/// or any of the given arguments if the input is an array.
/// </summary>
public bool Contains(string arg) => Match(str => str.Contains(arg), arr => arr.Any(Contains));
public ProcessArgs Concat(ProcessArgs other) =>
Match(
str => new ProcessArgs(string.Join(' ', str, other.ToString())),
arr => new ProcessArgs(arr.Concat(other.ToArray()).ToArray())
);
public ProcessArgs Prepend(ProcessArgs other) =>
Match(
str => new ProcessArgs(string.Join(' ', other.ToString(), str)),
arr => new ProcessArgs(other.ToArray().Concat(arr).ToArray())
);
/// <inheritdoc />
public IEnumerator<string> GetEnumerator()
{
return ToArray().AsEnumerable().GetEnumerator();
}
/// <inheritdoc />
public override string ToString()
{
return Match(str => str, arr => string.Join(' ', arr.Select(ProcessRunner.Quote)));
}
/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public string[] ToArray() =>
Match(
str => ArgumentsRegex().Matches(str).Select(x => x.Value.Trim('"')).ToArray(),
arr => arr
);
// Implicit conversions
public static implicit operator ProcessArgs(string input) => new(input);
public static implicit operator ProcessArgs(string[] input) => new(input);
public static implicit operator string(ProcessArgs input) => input.ToString();
public static implicit operator string[](ProcessArgs input) => input.ToArray();
[GeneratedRegex("""[\"].+?[\"]|[^ ]+""", RegexOptions.IgnoreCase)]
private static partial Regex ArgumentsRegex();
}

75
StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs

@ -0,0 +1,75 @@
using System.Diagnostics;
using OneOf;
namespace StabilityMatrix.Core.Processes;
/// <summary>
/// Builder for <see cref="ProcessArgs"/>.
/// </summary>
public record ProcessArgsBuilder
{
protected ProcessArgsBuilder() { }
public ProcessArgsBuilder(params Argument[] arguments)
{
Arguments = arguments.ToList();
}
public List<Argument> Arguments { get; init; } = new();
private IEnumerable<string> ToStringArgs()
{
foreach (var argument in Arguments)
{
if (argument.IsT0)
{
yield return argument.AsT0;
}
else
{
yield return argument.AsT1.Item1;
yield return argument.AsT1.Item2;
}
}
}
/// <inheritdoc />
public override string ToString()
{
return ToProcessArgs().ToString();
}
public ProcessArgs ToProcessArgs()
{
return ToStringArgs().ToArray();
}
public static implicit operator ProcessArgs(ProcessArgsBuilder builder) =>
builder.ToProcessArgs();
}
public static class ProcessArgBuilderExtensions
{
public static T AddArg<T>(this T builder, Argument argument)
where T : ProcessArgsBuilder
{
return builder with { Arguments = builder.Arguments.Append(argument).ToList() };
}
public static T RemoveArgKey<T>(this T builder, string argumentKey)
where T : ProcessArgsBuilder
{
return builder with
{
Arguments = builder.Arguments
.Where(
x =>
x.Match(
stringArg => stringArg != argumentKey,
tupleArg => tupleArg.Item1 != argumentKey
)
)
.ToList()
};
}
}

31
StabilityMatrix.Core/Python/PipInstallArgs.cs

@ -0,0 +1,31 @@
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Python;
public record PipInstallArgs : ProcessArgsBuilder
{
public PipInstallArgs(params Argument[] arguments)
: base(arguments) { }
public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}");
public PipInstallArgs WithTorchDirectML(string version = "") =>
this.AddArg($"torch-directml{version}");
public PipInstallArgs WithTorchVision(string version = "") =>
this.AddArg($"torchvision{version}");
public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}");
public PipInstallArgs WithExtraIndex(string indexUrl) =>
this.AddArg(("--extra-index-url", indexUrl));
public PipInstallArgs WithTorchExtraIndex(string index) =>
this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}"));
/// <inheritdoc />
public override string ToString()
{
return base.ToString();
}
}

33
StabilityMatrix.Core/Python/PyVenvRunner.cs

@ -19,25 +19,6 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private const string TorchPipInstallArgs = "torch==2.0.1 torchvision";
public const string TorchPipInstallArgsCuda =
$"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/cu118";
public const string TorchPipInstallArgsCuda121 =
"torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121";
public const string TorchPipInstallArgsCpu = TorchPipInstallArgs;
public const string TorchPipInstallArgsDirectML = "torch-directml";
public const string TorchPipInstallArgsRocm511 =
$"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.1.1";
public const string TorchPipInstallArgsRocm542 =
$"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.4.2";
public const string TorchPipInstallArgsRocm56 =
$"{TorchPipInstallArgs} --index-url https://download.pytorch.org/whl/rocm5.6";
public const string TorchPipInstallArgsNightlyCpu =
"--pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu";
/// <summary>
/// Relative path to the site-packages folder from the venv root.
/// This is platform specific.
@ -216,7 +197,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
/// Run a pip install command. Waits for the process to exit.
/// workingDirectory defaults to RootPath.
/// </summary>
public async Task PipInstall(string args, Action<ProcessOutput>? outputDataReceived = null)
public async Task PipInstall(ProcessArgs args, Action<ProcessOutput>? outputDataReceived = null)
{
if (!File.Exists(PipPath))
{
@ -236,7 +217,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
});
SetPyvenvCfg(PyRunner.PythonDir);
RunDetached($"-m pip install {args}", outputAction);
RunDetached(args.Prepend("-m pip install"), outputAction);
await Process.WaitForExitAsync().ConfigureAwait(false);
// Check return code
@ -349,7 +330,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
/// Run a command using the venv Python executable and return the result.
/// </summary>
/// <param name="arguments">Arguments to pass to the Python executable.</param>
public async Task<ProcessResult> Run(string arguments)
public async Task<ProcessResult> Run(ProcessArgs arguments)
{
// Record output for errors
var output = new StringBuilder();
@ -381,12 +362,14 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
[MemberNotNull(nameof(Process))]
public void RunDetached(
string arguments,
ProcessArgs args,
Action<ProcessOutput>? outputDataReceived,
Action<int>? onExit = null,
bool unbuffered = true
)
{
var arguments = args.ToString();
if (!PythonPath.Exists)
{
throw new FileNotFoundException("Venv python not found", PythonPath);
@ -395,10 +378,10 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable
Logger.Info(
"Launching venv process [{PythonPath}] "
+ "in working directory [{WorkingDirectory}] with args {arguments.ToRepr()}",
+ "in working directory [{WorkingDirectory}] with args {Arguments}",
PythonPath,
WorkingDirectory,
arguments.ToRepr()
arguments
);
var filteredOutput =

2
StabilityMatrix.Core/StabilityMatrix.Core.csproj

@ -32,6 +32,8 @@
<PackageReference Include="NLog.Extensions.Logging" Version="5.3.5" />
<PackageReference Include="NSec.Cryptography" Version="22.4.0" />
<PackageReference Include="Octokit" Version="8.1.1" />
<PackageReference Include="OneOf" Version="3.0.263" />
<PackageReference Include="OneOf.SourceGenerator" Version="3.0.263" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Refit" Version="7.0.0" />

61
StabilityMatrix.Tests/Core/PipInstallArgsTests.cs

@ -0,0 +1,61 @@
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
namespace StabilityMatrix.Tests.Core;
[TestClass]
public class PipInstallArgsTests
{
[TestMethod]
public void TestGetTorch()
{
// Arrange
const string version = "==2.1.0";
// Act
var args = new PipInstallArgs().WithTorch(version).ToProcessArgs().ToString();
// Assert
Assert.AreEqual("torch==2.1.0", args);
}
[TestMethod]
public void TestGetTorchWithExtraIndex()
{
// Arrange
const string version = ">=2.0.0";
const string index = "cu118";
// Act
var args = new PipInstallArgs()
.WithTorch(version)
.WithTorchVision()
.WithTorchExtraIndex(index)
.ToProcessArgs()
.ToString();
// Assert
Assert.AreEqual(
"torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118",
args
);
}
[TestMethod]
public void TestGetTorchWithMoreStuff()
{
// Act
var args = new PipInstallArgs()
.AddArg("--pre")
.WithTorch("~=2.0.0")
.WithTorchVision()
.WithTorchExtraIndex("nightly/cpu")
.ToString();
// Assert
Assert.AreEqual(
"--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu",
args
);
}
}

43
StabilityMatrix.Tests/Models/ProcessArgsTests.cs

@ -0,0 +1,43 @@
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Tests.Models;
[TestClass]
public class ProcessArgsTests
{
[DataTestMethod]
[DataRow("pip", new[] { "pip" })]
[DataRow("pip install torch", new[] { "pip", "install", "torch" })]
[DataRow(
"pip install -r \"file spaces/here\"",
new[] { "pip", "install", "-r", "file spaces/here" }
)]
[DataRow(
"pip install -r \"file spaces\\here\"",
new[] { "pip", "install", "-r", "file spaces\\here" }
)]
public void TestStringToArray(string input, string[] expected)
{
ProcessArgs args = input;
string[] result = args;
CollectionAssert.AreEqual(expected, result);
}
[DataTestMethod]
[DataRow(new[] { "pip" }, "pip")]
[DataRow(new[] { "pip", "install", "torch" }, "pip install torch")]
[DataRow(
new[] { "pip", "install", "-r", "file spaces/here" },
"pip install -r \"file spaces/here\""
)]
[DataRow(
new[] { "pip", "install", "-r", "file spaces\\here" },
"pip install -r \"file spaces\\here\""
)]
public void TestArrayToString(string[] input, string expected)
{
ProcessArgs args = input;
string result = args;
Assert.AreEqual(expected, result);
}
}
Loading…
Cancel
Save