diff --git a/CHANGELOG.md b/CHANGELOG.md index 557ed5de..5daa2ae2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2 - Update will occur the next time the package is updated, or on a fresh install - Note: CUDA 12.1 is only available on Maxwell (GTX 900 series) and newer GPUs - Improved Model Browser download stability with automatic retries for download errors +- Optimized page navigation and syntax formatting configurations to improve startup time ### Fixed - Fixed crash when clicking Inference gallery image after the image is deleted externally in file explorer - Fixed Inference popup Install button not working on One-Click Installer diff --git a/StabilityMatrix.Avalonia/App.axaml.cs b/StabilityMatrix.Avalonia/App.axaml.cs index b0cd7885..40ae18ba 100644 --- a/StabilityMatrix.Avalonia/App.axaml.cs +++ b/StabilityMatrix.Avalonia/App.axaml.cs @@ -383,11 +383,7 @@ public sealed class App : Application services.AddSingleton(); } - if (Design.IsDesignMode) - { - services.AddSingleton(); - } - else + if (!Design.IsDesignMode) { services.AddSingleton(); services.AddSingleton(p => p.GetRequiredService()); diff --git a/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml b/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml index 3fe4071e..6fb6e844 100644 --- a/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml +++ b/StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml @@ -13,11 +13,6 @@ - - - + + diff --git a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs index 38300cff..f5a79dcb 100644 --- a/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/PromptCard.axaml.cs @@ -6,6 +6,7 @@ using AvaloniaEdit; using AvaloniaEdit.Editing; using AvaloniaEdit.Utils; using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Core.Attributes; namespace StabilityMatrix.Avalonia.Controls; @@ -33,7 +34,7 @@ public class PromptCard : TemplatedControl { if (editor is not null) { - TextEditorConfigs.ConfigForPrompt(editor); + TextEditorConfigs.Configure(editor, TextEditorPreset.Prompt); editor.TextArea.Margin = new Thickness(0, 0, 4, 0); if (editor.TextArea.ActiveInputHandler is TextAreaInputHandler inputHandler) diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs index 189a4f0b..62519334 100644 --- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs +++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs @@ -9,6 +9,8 @@ using AvaloniaEdit.Utils; using DynamicData; using DynamicData.Binding; using Microsoft.Extensions.DependencyInjection; +using NSubstitute; +using NSubstitute.ReturnsExtensions; using StabilityMatrix.Avalonia.Controls.CodeCompletion; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.TagCompletion; @@ -109,17 +111,18 @@ public static class DesignData // Mock services services - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) .AddSingleton() - .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton() - .AddSingleton(); + .AddSingleton(); // Placeholder services that nobody should need during design time services diff --git a/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs b/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs deleted file mode 100644 index 6bffe945..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockApiFactory.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using StabilityMatrix.Core.Api; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockApiFactory : IApiFactory -{ - public T CreateRefitClient(Uri baseAddress) - { - throw new NotImplementedException(); - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs b/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs deleted file mode 100644 index 79851cd7..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockDiscordRichPresenceService.cs +++ /dev/null @@ -1,18 +0,0 @@ -using System; -using StabilityMatrix.Avalonia.Services; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockDiscordRichPresenceService : IDiscordRichPresenceService -{ - /// - public void Dispose() - { - GC.SuppressFinalize(this); - } - - /// - public void UpdateState() - { - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs deleted file mode 100644 index b1e08d8b..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockDownloadService.cs +++ /dev/null @@ -1,50 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using StabilityMatrix.Core.Models.Progress; -using StabilityMatrix.Core.Services; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockDownloadService : IDownloadService -{ - public Task DownloadToFileAsync( - string downloadUrl, - string downloadPath, - IProgress? progress = null, - string? httpClientName = null, - CancellationToken cancellationToken = default - ) - { - return Task.CompletedTask; - } - - /// - public Task ResumeDownloadToFileAsync( - string downloadUrl, - string downloadPath, - long existingFileSize, - IProgress? progress = null, - string? httpClientName = null, - CancellationToken cancellationToken = default - ) - { - return Task.CompletedTask; - } - - /// - public Task GetFileSizeAsync( - string downloadUrl, - string? httpClientName = null, - CancellationToken cancellationToken = default - ) - { - return Task.FromResult(0L); - } - - public Task GetImageStreamFromUrl(string url) - { - return Task.FromResult(new MemoryStream(new byte[24]) as Stream)!; - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs b/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs deleted file mode 100644 index 452af8c4..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockHttpClientFactory.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Net.Http; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockHttpClientFactory : IHttpClientFactory -{ - public HttpClient CreateClient(string name) - { - throw new NotImplementedException(); - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs b/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs deleted file mode 100644 index a8f2996e..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs +++ /dev/null @@ -1,62 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using LiteDB.Async; -using StabilityMatrix.Core.Database; -using StabilityMatrix.Core.Models.Api; -using StabilityMatrix.Core.Models.Database; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockLiteDbContext : ILiteDbContext -{ - public LiteDatabaseAsync Database => throw new NotImplementedException(); - public ILiteCollectionAsync CivitModels => throw new NotImplementedException(); - public ILiteCollectionAsync CivitModelVersions => - throw new NotImplementedException(); - public ILiteCollectionAsync CivitModelQueryCache => - throw new NotImplementedException(); - public ILiteCollectionAsync LocalModelFiles => - throw new NotImplementedException(); - public ILiteCollectionAsync InferenceProjects => - throw new NotImplementedException(); - public ILiteCollectionAsync LocalImageFiles => - throw new NotImplementedException(); - - public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync( - string hashBlake3 - ) - { - return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null)); - } - - public Task UpsertCivitModelAsync(CivitModel civitModel) - { - return Task.FromResult(true); - } - - public Task UpsertCivitModelAsync(IEnumerable civitModels) - { - return Task.FromResult(true); - } - - public Task UpsertCivitModelQueryCacheEntryAsync(CivitModelQueryCacheEntry entry) - { - return Task.FromResult(true); - } - - public Task GetGithubCacheEntry(string cacheKey) - { - return Task.FromResult(null); - } - - public Task UpsertGithubCacheEntry(GithubCacheEntry cacheEntry) - { - return Task.FromResult(true); - } - - public void Dispose() - { - GC.SuppressFinalize(this); - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs b/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs deleted file mode 100644 index 4fdde7ad..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockNotificationService.cs +++ /dev/null @@ -1,47 +0,0 @@ -using System; -using System.Threading.Tasks; -using Avalonia; -using Avalonia.Controls.Notifications; -using StabilityMatrix.Avalonia.Services; -using StabilityMatrix.Core.Models; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockNotificationService : INotificationService -{ - public void Initialize(Visual? visual, - NotificationPosition position = NotificationPosition.BottomRight, int maxItems = 3) - { - } - - public void Show(INotification notification) - { - } - - public Task> TryAsync(Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error) - { - return Task.FromResult(new TaskResult(default!)); - } - - public Task> TryAsync(Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error) - { - return Task.FromResult(new TaskResult(true)); - } - - public void Show( - string title, - string message, - NotificationType appearance = NotificationType.Information, - TimeSpan? expiration = null) - { - } - - public void ShowPersistent( - string title, - string message, - NotificationType appearance = NotificationType.Information) - { - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs b/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs deleted file mode 100644 index 35c21160..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockSharedFolders.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System.Threading.Tasks; -using StabilityMatrix.Core.Helper; -using StabilityMatrix.Core.Models.FileInterfaces; -using StabilityMatrix.Core.Models.Packages; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockSharedFolders : ISharedFolders -{ - public void SetupLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory) - { - } - - public Task UpdateLinksForPackage(BasePackage basePackage, DirectoryPath installDirectory) - { - return Task.CompletedTask; - } - - public void RemoveLinksForAllPackages() - { - } - - public void SetupSharedModelFolders() - { - } -} diff --git a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs b/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs deleted file mode 100644 index 4522bd2d..00000000 --- a/StabilityMatrix.Avalonia/DesignData/MockTrackedDownloadService.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; -using System.Collections.Generic; -using StabilityMatrix.Core.Models; -using StabilityMatrix.Core.Models.FileInterfaces; -using StabilityMatrix.Core.Services; - -namespace StabilityMatrix.Avalonia.DesignData; - -public class MockTrackedDownloadService : ITrackedDownloadService -{ - /// - public IEnumerable Downloads => Array.Empty(); - - /// - public event EventHandler? DownloadAdded; - - /// - public TrackedDownload NewDownload(Uri downloadUrl, FilePath downloadPath) - { - throw new NotImplementedException(); - } -} diff --git a/StabilityMatrix.Avalonia/DialogHelper.cs b/StabilityMatrix.Avalonia/DialogHelper.cs index b244afe3..92f7dbc8 100644 --- a/StabilityMatrix.Avalonia/DialogHelper.cs +++ b/StabilityMatrix.Avalonia/DialogHelper.cs @@ -424,7 +424,7 @@ public static class DialogHelper AllowScrollBelowDocument = false } }; - TextEditorConfigs.ConfigForPrompt(textEditor); + TextEditorConfigs.Configure(textEditor, TextEditorPreset.Prompt); textEditor.Document.Text = errorLineFormatted; textEditor.TextArea.Caret.Offset = textEditor.Document.Lines[0].EndOffset; diff --git a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs index e04f4f46..81691e93 100644 --- a/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs +++ b/StabilityMatrix.Avalonia/Extensions/ComfyNodeBuilderExtensions.cs @@ -184,7 +184,7 @@ public static class ComfyNodeBuilderExtensions var checkpointLoader = builder.Nodes.AddNamedNode( ComfyNodeBuilder.CheckpointLoaderSimple( "Refiner_CheckpointLoader", - modelCardViewModel.SelectedRefiner?.FileName + modelCardViewModel.SelectedRefiner?.RelativePath ?? throw new NullReferenceException("Model not selected") ) ); diff --git a/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs b/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs index 95543d51..83a744eb 100644 --- a/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs +++ b/StabilityMatrix.Avalonia/Helpers/TextEditorConfigs.cs @@ -1,4 +1,5 @@ -using System.IO; +using System; +using System.IO; using Avalonia.Media; using AvaloniaEdit; using AvaloniaEdit.TextMate; @@ -17,13 +18,22 @@ public static class TextEditorConfigs { public static void Configure(TextEditor editor, TextEditorPreset preset) { - if (preset == TextEditorPreset.Prompt) + switch (preset) { - ConfigForPrompt(editor); + case TextEditorPreset.Prompt: + ConfigForPrompt(editor); + break; + case TextEditorPreset.Console: + ConfigForConsole(editor); + break; + case TextEditorPreset.None: + break; + default: + throw new ArgumentOutOfRangeException(nameof(preset), preset, null); } } - public static void ConfigForPrompt(TextEditor editor) + private static void ConfigForPrompt(TextEditor editor) { const ThemeName themeName = ThemeName.DimmedMonokai; var registryOptions = new RegistryOptions(themeName); @@ -57,6 +67,25 @@ public static class TextEditorConfigs installation.SetTheme(theme); } + private static void ConfigForConsole(TextEditor editor) + { + var registryOptions = new RegistryOptions(ThemeName.DarkPlus); + + // Config hyperlinks + editor.TextArea.Options.EnableHyperlinks = true; + editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false; + editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral; + + var textMate = editor.InstallTextMate(registryOptions); + var scope = registryOptions.GetScopeByLanguageId("log"); + + if (scope is null) + throw new InvalidOperationException("Scope is null"); + + textMate.SetGrammar(scope); + textMate.SetTheme(registryOptions.LoadTheme(ThemeName.DarkPlus)); + } + private static IRawTheme GetThemeFromStream(Stream stream) { using var reader = new StreamReader(stream); diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs index 9210adc0..3b17284b 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatPart.cs @@ -1,16 +1,7 @@ using System; -using System.Runtime.InteropServices; -using CSharpDiscriminatedUnion.Attributes; +using OneOf; namespace StabilityMatrix.Avalonia.Models.Inference; -[GenerateDiscriminatedUnion(CaseFactoryPrefix = "From")] -[StructLayout(LayoutKind.Auto)] -public readonly partial struct FileNameFormatPart -{ - [StructCase("Constant", isDefaultValue: true)] - private readonly string constant; - - [StructCase("Substitution")] - private readonly Func substitution; -} +[GenerateOneOf] +public partial class FileNameFormatPart : OneOfBase> { } diff --git a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs index e8c13f79..ff6905fd 100644 --- a/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs +++ b/StabilityMatrix.Avalonia/Models/Inference/FileNameFormatProvider.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.ComponentModel.DataAnnotations; using System.Diagnostics.Contracts; +using System.IO; using System.Linq; using System.Text.RegularExpressions; using Avalonia.Data; @@ -26,7 +27,10 @@ public partial class FileNameFormatProvider { "seed", () => GenerationParameters?.Seed.ToString() }, { "prompt", () => GenerationParameters?.PositivePrompt }, { "negative_prompt", () => GenerationParameters?.NegativePrompt }, - { "model_name", () => GenerationParameters?.ModelName }, + { + "model_name", + () => Path.GetFileNameWithoutExtension(GenerationParameters?.ModelName) + }, { "model_hash", () => GenerationParameters?.ModelHash }, { "width", () => GenerationParameters?.Width.ToString() }, { "height", () => GenerationParameters?.Height.ToString() }, @@ -84,7 +88,7 @@ public partial class FileNameFormatProvider if (result.Index != currentIndex) { var constant = template[currentIndex..result.Index]; - parts.Add(FileNameFormatPart.FromConstant(constant)); + parts.Add(constant); currentIndex += constant.Length; } @@ -97,30 +101,32 @@ public partial class FileNameFormatProvider if (slice is not null) { parts.Add( - FileNameFormatPart.FromSubstitution(() => - { - var value = substitution(); - if (value is null) - return null; - - if (slice.End is null) + (FileNameFormatPart)( + () => { - value = value[(slice.Start ?? 0)..]; + var value = substitution(); + if (value is null) + return null; + + if (slice.End is null) + { + value = value[(slice.Start ?? 0)..]; + } + else + { + var length = + Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0); + value = value.Substring(slice.Start ?? 0, length); + } + + return value; } - else - { - var length = - Math.Min(value.Length, slice.End.Value) - (slice.Start ?? 0); - value = value.Substring(slice.Start ?? 0, length); - } - - return value; - }) + ) ); } else { - parts.Add(FileNameFormatPart.FromSubstitution(substitution)); + parts.Add(substitution); } currentIndex += result.Length; @@ -130,7 +136,7 @@ public partial class FileNameFormatProvider if (currentIndex != template.Length) { var constant = template[currentIndex..]; - parts.Add(FileNameFormatPart.FromConstant(constant)); + parts.Add(constant); } return parts; diff --git a/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs b/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs index 3caf0b86..50416643 100644 --- a/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs +++ b/StabilityMatrix.Avalonia/Models/TextEditorPreset.cs @@ -3,5 +3,6 @@ public enum TextEditorPreset { None, - Prompt + Prompt, + Console } diff --git a/StabilityMatrix.Avalonia/Services/INotificationService.cs b/StabilityMatrix.Avalonia/Services/INotificationService.cs index 77a4d9d2..d3cc7e77 100644 --- a/StabilityMatrix.Avalonia/Services/INotificationService.cs +++ b/StabilityMatrix.Avalonia/Services/INotificationService.cs @@ -2,6 +2,8 @@ using System.Threading.Tasks; using Avalonia; using Avalonia.Controls.Notifications; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.Services; @@ -11,7 +13,8 @@ public interface INotificationService public void Initialize( Visual? visual, NotificationPosition position = NotificationPosition.BottomRight, - int maxItems = 3); + int maxItems = 3 + ); public void Show(INotification notification); @@ -26,7 +29,8 @@ public interface INotificationService Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error); + NotificationType appearance = NotificationType.Error + ); /// /// Attempt to run the given void task, showing a generic error notification if it fails. @@ -40,16 +44,18 @@ public interface INotificationService Task task, string title = "Error", string? message = null, - NotificationType appearance = NotificationType.Error); + NotificationType appearance = NotificationType.Error + ); /// /// Show a notification with the given parameters. /// void Show( - string title, + string title, string message, NotificationType appearance = NotificationType.Information, - TimeSpan? expiration = null); + TimeSpan? expiration = null + ); /// /// Show a notification that will not auto-dismiss. @@ -60,5 +66,15 @@ public interface INotificationService void ShowPersistent( string title, string message, - NotificationType appearance = NotificationType.Information); + NotificationType appearance = NotificationType.Information + ); + + /// + /// Show a notification for a that will not auto-dismiss. + /// + void ShowPersistent( + AppException exception, + NotificationType appearance = NotificationType.Error, + LogLevel logLevel = LogLevel.Warning + ); } diff --git a/StabilityMatrix.Avalonia/Services/NotificationService.cs b/StabilityMatrix.Avalonia/Services/NotificationService.cs index 88a3d815..98611f8d 100644 --- a/StabilityMatrix.Avalonia/Services/NotificationService.cs +++ b/StabilityMatrix.Avalonia/Services/NotificationService.cs @@ -3,7 +3,9 @@ using System.Threading.Tasks; using Avalonia; using Avalonia.Controls; using Avalonia.Controls.Notifications; +using Microsoft.Extensions.Logging; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.Services; @@ -11,8 +13,14 @@ namespace StabilityMatrix.Avalonia.Services; [Singleton(typeof(INotificationService))] public class NotificationService : INotificationService { + private readonly ILogger logger; private WindowNotificationManager? notificationManager; + public NotificationService(ILogger logger) + { + this.logger = logger; + } + public void Initialize( Visual? visual, NotificationPosition position = NotificationPosition.BottomRight, @@ -52,6 +60,19 @@ public class NotificationService : INotificationService Show(new Notification(title, message, appearance, TimeSpan.Zero)); } + /// + public void ShowPersistent( + AppException exception, + NotificationType appearance = NotificationType.Warning, + LogLevel logLevel = LogLevel.Warning + ) + { + // Log exception + logger.Log(logLevel, exception, "{Message}", exception.Message); + + Show(new Notification(exception.Message, exception.Details, appearance, TimeSpan.Zero)); + } + /// public async Task> TryAsync( Task task, diff --git a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj index 82b8fde9..68eb7e14 100644 --- a/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj +++ b/StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj @@ -32,7 +32,6 @@ - @@ -53,6 +52,7 @@ + diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 5c121026..c06c438c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -251,7 +251,7 @@ public class InferenceTextToImageViewModel if (ModelCardViewModel is { IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false }) { var customVaeLoader = nodes.AddNamedNode( - ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.FileName) + ComfyNodeBuilder.VAELoader("VAELoader", ModelCardViewModel.SelectedVae.RelativePath) ); builder.Connections.BaseVAE = customVaeLoader.Output; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs index ed55236b..d69f5b80 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs @@ -31,9 +31,9 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad [ObservableProperty] private bool isVaeSelectionEnabled; - public string? SelectedModelName => SelectedModel?.FileName; + public string? SelectedModelName => SelectedModel?.RelativePath; - public string? SelectedVaeName => SelectedVae?.FileName; + public string? SelectedVaeName => SelectedVae?.RelativePath; public IInferenceClientManager ClientManager { get; } @@ -62,11 +62,11 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad SelectedModel = model.SelectedModelName is null ? null - : ClientManager.Models.FirstOrDefault(x => x.FileName == model.SelectedModelName); + : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); SelectedVae = model.SelectedVaeName is null ? HybridModelFile.Default - : ClientManager.VaeModels.FirstOrDefault(x => x.FileName == model.SelectedVaeName); + : ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName); } internal class ModelCardModel @@ -101,7 +101,7 @@ public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoad else { // Name matches - model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName)); + model = currentModels.FirstOrDefault(m => m.RelativePath.EndsWith(paramsModelName)); model ??= currentModels.FirstOrDefault( m => m.ShortDisplayName.StartsWith(paramsModelName) ); diff --git a/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs b/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs index 74e8d450..4ccd4227 100644 --- a/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/LaunchPageView.axaml.cs @@ -1,14 +1,14 @@ using System; using Avalonia.Controls; +using Avalonia.Controls.Primitives; using Avalonia.Interactivity; -using Avalonia.Media; using Avalonia.Threading; using AvaloniaEdit; -using AvaloniaEdit.TextMate; using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; -using TextMateSharp.Grammars; namespace StabilityMatrix.Avalonia.Views; @@ -20,25 +20,14 @@ public partial class LaunchPageView : UserControlBase public LaunchPageView() { InitializeComponent(); - var editor = this.FindControl("Console"); - if (editor is not null) - { - var options = new RegistryOptions(ThemeName.DarkPlus); - - // Config hyperlinks - editor.TextArea.Options.EnableHyperlinks = true; - editor.TextArea.Options.RequireControlModifierForHyperlinkClick = false; - editor.TextArea.TextView.LinkTextForegroundBrush = Brushes.Coral; - - var textMate = editor.InstallTextMate(options); - var scope = options.GetScopeByLanguageId("log"); + } - if (scope is null) - throw new InvalidOperationException("Scope is null"); + /// + protected override void OnApplyTemplate(TemplateAppliedEventArgs e) + { + base.OnApplyTemplate(e); - textMate.SetGrammar(scope); - textMate.SetTheme(options.LoadTheme(ThemeName.DarkPlus)); - } + TextEditorConfigs.Configure(Console, TextEditorPreset.Console); } protected override void OnUnloaded(RoutedEventArgs e) diff --git a/StabilityMatrix.Avalonia/Views/MainWindow.axaml.cs b/StabilityMatrix.Avalonia/Views/MainWindow.axaml.cs index 2e9d5f27..57fc1c04 100644 --- a/StabilityMatrix.Avalonia/Views/MainWindow.axaml.cs +++ b/StabilityMatrix.Avalonia/Views/MainWindow.axaml.cs @@ -87,14 +87,6 @@ public partial class MainWindow : AppWindowBase navigationService.SetFrame( FrameView ?? throw new NullReferenceException("Frame not found") ); - - // Navigate to first page - if (DataContext is not MainWindowViewModel vm) - { - throw new NullReferenceException("DataContext is not MainWindowViewModel"); - } - - navigationService.NavigateTo(vm.Pages[0], new DrillInNavigationTransitionInfo()); } protected override void OnOpened(EventArgs e) @@ -133,8 +125,23 @@ public partial class MainWindow : AppWindowBase loader.LoadFailed += OnImageLoadFailed; } + if (DataContext is not MainWindowViewModel vm) + return; + + // Navigate to first page + Dispatcher.UIThread.Post( + () => + navigationService.NavigateTo( + vm.Pages[0], + new BetterSlideNavigationTransition + { + Effect = SlideNavigationTransitionEffect.FromBottom + } + ) + ); + // Check show update teaching tip - if (DataContext is MainWindowViewModel { UpdateViewModel.IsUpdateAvailable: true } vm) + if (vm.UpdateViewModel.IsUpdateAvailable) { OnUpdateAvailable(this, vm.UpdateViewModel.UpdateInfo); } diff --git a/StabilityMatrix.Core/Attributes/SingletonAttribute.cs b/StabilityMatrix.Core/Attributes/SingletonAttribute.cs index f6897601..d6538a4e 100644 --- a/StabilityMatrix.Core/Attributes/SingletonAttribute.cs +++ b/StabilityMatrix.Core/Attributes/SingletonAttribute.cs @@ -1,8 +1,12 @@ -namespace StabilityMatrix.Core.Attributes; +using System.Diagnostics.CodeAnalysis; +namespace StabilityMatrix.Core.Attributes; + +[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] [AttributeUsage(AttributeTargets.Class)] public class SingletonAttribute : Attribute { + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] public Type? InterfaceType { get; init; } public SingletonAttribute() { } diff --git a/StabilityMatrix.Core/Attributes/TransientAttribute.cs b/StabilityMatrix.Core/Attributes/TransientAttribute.cs index e3d7e081..acccb775 100644 --- a/StabilityMatrix.Core/Attributes/TransientAttribute.cs +++ b/StabilityMatrix.Core/Attributes/TransientAttribute.cs @@ -1,8 +1,12 @@ -namespace StabilityMatrix.Core.Attributes; +using System.Diagnostics.CodeAnalysis; +namespace StabilityMatrix.Core.Attributes; + +[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] [AttributeUsage(AttributeTargets.Class)] public class TransientAttribute : Attribute { + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] public Type? InterfaceType { get; init; } public TransientAttribute() { } diff --git a/StabilityMatrix.Core/Exceptions/AppException.cs b/StabilityMatrix.Core/Exceptions/AppException.cs new file mode 100644 index 00000000..8c3b1f9f --- /dev/null +++ b/StabilityMatrix.Core/Exceptions/AppException.cs @@ -0,0 +1,9 @@ +namespace StabilityMatrix.Core.Exceptions; + +/// +/// Generic runtime exception with custom handling by notification service +/// +public class AppException : ApplicationException +{ + public string? Details { get; init; } +} diff --git a/StabilityMatrix.Core/Extensions/ProgressExtensions.cs b/StabilityMatrix.Core/Extensions/ProgressExtensions.cs new file mode 100644 index 00000000..bd344ad9 --- /dev/null +++ b/StabilityMatrix.Core/Extensions/ProgressExtensions.cs @@ -0,0 +1,23 @@ +using System.Diagnostics.CodeAnalysis; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Core.Extensions; + +public static class ProgressExtensions +{ + [return: NotNullIfNotNull(nameof(progress))] + public static Action? AsProcessOutputHandler( + this IProgress? progress + ) + { + return progress == null + ? null + : output => + { + progress.Report( + new ProgressReport { IsIndeterminate = true, Message = output.Text } + ); + }; + } +} diff --git a/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs b/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs index 8cab7af5..a996faf6 100644 --- a/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs +++ b/StabilityMatrix.Core/Models/FileInterfaces/DirectoryPath.cs @@ -1,4 +1,5 @@ -using System.Diagnostics.CodeAnalysis; +using System.Collections; +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; using StabilityMatrix.Core.Converters.Json; @@ -6,7 +7,7 @@ namespace StabilityMatrix.Core.Models.FileInterfaces; [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] [JsonConverter(typeof(StringJsonConverter))] -public class DirectoryPath : FileSystemPath, IPathObject +public class DirectoryPath : FileSystemPath, IPathObject, IEnumerable { private DirectoryInfo? info; @@ -133,8 +134,50 @@ public class DirectoryPath : FileSystemPath, IPathObject public FilePath JoinFile(params FilePath[] paths) => new(Path.Combine(FullPath, Path.Combine(paths.Select(path => path.FullPath).ToArray()))); + /// + /// Returns an enumerable collection of files that matches + /// a specified search pattern and search subdirectory option. + /// + public IEnumerable EnumerateFiles( + string searchPattern = "*", + SearchOption searchOption = SearchOption.TopDirectoryOnly + ) => Info.EnumerateFiles(searchPattern, searchOption).Select(file => new FilePath(file)); + + /// + /// Returns an enumerable collection of directories that matches + /// a specified search pattern and search subdirectory option. + /// + public IEnumerable EnumerateDirectories( + string searchPattern = "*", + SearchOption searchOption = SearchOption.TopDirectoryOnly + ) => + Info.EnumerateDirectories(searchPattern, searchOption) + .Select(directory => new DirectoryPath(directory)); + public override string ToString() => FullPath; + /// + public IEnumerator GetEnumerator() + { + return Info.EnumerateFileSystemInfos("*", SearchOption.TopDirectoryOnly) + .Select( + fsInfo => + fsInfo switch + { + FileInfo file => new FilePath(file), + DirectoryInfo directory => new DirectoryPath(directory), + _ => throw new InvalidOperationException("Unknown file system info type") + } + ) + .GetEnumerator(); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + // DirectoryPath + DirectoryPath = DirectoryPath public static DirectoryPath operator +(DirectoryPath path, DirectoryPath other) => new(Path.Combine(path, other.FullPath)); diff --git a/StabilityMatrix.Core/Models/HybridModelFile.cs b/StabilityMatrix.Core/Models/HybridModelFile.cs index 2f178fd0..60fd8d4f 100644 --- a/StabilityMatrix.Core/Models/HybridModelFile.cs +++ b/StabilityMatrix.Core/Models/HybridModelFile.cs @@ -30,7 +30,10 @@ public record HybridModelFile public bool IsRemote => RemoteName != null; [JsonIgnore] - public string FileName => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder; + public string RelativePath => IsRemote ? RemoteName : Local.RelativePathFromSharedFolder; + + [JsonIgnore] + public string FileName => Path.GetFileName(RelativePath); [JsonIgnore] public string ShortDisplayName @@ -47,7 +50,7 @@ public record HybridModelFile return "Default"; } - return Path.GetFileNameWithoutExtension(FileName); + return Path.GetFileNameWithoutExtension(RelativePath); } } @@ -63,7 +66,7 @@ public record HybridModelFile public string GetId() { - return $"{FileName};{IsNone};{IsDefault}"; + return $"{RelativePath};{IsNone};{IsDefault}"; } private sealed class RemoteNameLocalEqualityComparer : IEqualityComparer @@ -79,14 +82,14 @@ public record HybridModelFile if (x.GetType() != y.GetType()) return false; - return Equals(x.FileName, y.FileName) + return Equals(x.RelativePath, y.RelativePath) && x.IsNone == y.IsNone && x.IsDefault == y.IsDefault; } public int GetHashCode(HybridModelFile obj) { - return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.FileName); + return HashCode.Combine(obj.IsNone, obj.IsDefault, obj.RelativePath); } } diff --git a/StabilityMatrix.Core/Models/PackageModification/PipStep.cs b/StabilityMatrix.Core/Models/PackageModification/PipStep.cs new file mode 100644 index 00000000..e672da53 --- /dev/null +++ b/StabilityMatrix.Core/Models/PackageModification/PipStep.cs @@ -0,0 +1,42 @@ +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Models.PackageModification; + +public class PipStep : IPackageStep +{ + public required ProcessArgs Args { get; init; } + public required DirectoryPath VenvDirectory { get; init; } + + public DirectoryPath? WorkingDirectory { get; init; } + + public IReadOnlyDictionary? EnvironmentVariables { get; init; } + + /// + public string ProgressTitle => + Args switch + { + _ when Args.Contains("install") => "Installing Pip Packages", + _ when Args.Contains("uninstall") => "Uninstalling Pip Packages", + _ when Args.Contains("-U") || Args.Contains("--upgrade") => "Updating Pip Packages", + _ => "Running Pip" + }; + + /// + public async Task ExecuteAsync(IProgress? progress = null) + { + await using var venvRunner = new PyVenvRunner(VenvDirectory) + { + WorkingDirectory = WorkingDirectory, + EnvironmentVariables = EnvironmentVariables + }; + + var args = new List { "-m", "pip" }; + args.AddRange(Args.ToArray()); + + venvRunner.RunDetached(args.ToArray(), progress.AsProcessOutputHandler()); + } +} diff --git a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs index 665d68c6..44cce164 100644 --- a/StabilityMatrix.Core/Models/Packages/A3WebUI.cs +++ b/StabilityMatrix.Core/Models/Packages/A3WebUI.cs @@ -281,7 +281,13 @@ public class A3WebUI : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.1.1"), + onConsoleOutput + ) .ConfigureAwait(false); } } diff --git a/StabilityMatrix.Core/Models/Packages/BasePackage.cs b/StabilityMatrix.Core/Models/Packages/BasePackage.cs index e1bc0ae1..3558f79a 100644 --- a/StabilityMatrix.Core/Models/Packages/BasePackage.cs +++ b/StabilityMatrix.Core/Models/Packages/BasePackage.cs @@ -189,9 +189,14 @@ public abstract class BasePackage ); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsCuda, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithXFormers("==0.0.20") + .WithTorchExtraIndex("cu118"), + onConsoleOutput + ) .ConfigureAwait(false); - await venvRunner.PipInstall("xformers==0.0.20", onConsoleOutput).ConfigureAwait(false); } protected Task InstallDirectMlTorch( @@ -204,7 +209,7 @@ public abstract class BasePackage new ProgressReport(-1f, "Installing PyTorch for DirectML", isIndeterminate: true) ); - return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput); + return venvRunner.PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput); } protected Task InstallCpuTorch( @@ -217,6 +222,9 @@ public abstract class BasePackage new ProgressReport(-1f, "Installing PyTorch for CPU", isIndeterminate: true) ); - return venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu, onConsoleOutput); + return venvRunner.PipInstall( + new PipInstallArgs().WithTorch("==2.0.1").WithTorchVision(), + onConsoleOutput + ); } } diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 2fe22438..bcb39eb3 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -179,15 +179,20 @@ public class ComfyUI : BaseGitPackage break; case TorchVersion.Cuda: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsCuda121, onConsoleOutput) - .ConfigureAwait(false); - await venvRunner - .PipInstall("xformers==0.0.22.post4 --upgrade") + .PipInstall( + new PipInstallArgs() + .WithTorch("~=2.1.0") + .WithTorchVision() + .WithXFormers("==0.0.22.post4") + .AddArg("--upgrade") + .WithTorchExtraIndex("cu121"), + onConsoleOutput + ) .ConfigureAwait(false); break; case TorchVersion.DirectMl: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML, onConsoleOutput) + .PipInstall(new PipInstallArgs().WithTorchDirectML(), onConsoleOutput) .ConfigureAwait(false); break; case TorchVersion.Rocm: @@ -195,7 +200,14 @@ public class ComfyUI : BaseGitPackage break; case TorchVersion.Mps: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsNightlyCpu, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .AddArg("--pre") + .WithTorch() + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu"), + onConsoleOutput + ) .ConfigureAwait(false); break; default: @@ -465,7 +477,13 @@ public class ComfyUI : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm56, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.6"), + onConsoleOutput + ) .ConfigureAwait(false); } diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs index ebb4b722..34830a65 100644 --- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs +++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs @@ -182,7 +182,13 @@ public class InvokeAI : BaseGitPackage // For AMD, Install ROCm version case TorchVersion.Rocm: await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm542, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithExtraIndex("rocm5.4.2"), + onConsoleOutput + ) .ConfigureAwait(false); Logger.Info("Starting InvokeAI install (ROCm)..."); pipCommandArgs = diff --git a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs index 5eb3461f..ec8e0052 100644 --- a/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs +++ b/StabilityMatrix.Core/Models/Packages/StableDiffusionUx.cs @@ -263,7 +263,13 @@ public class StableDiffusionUx : BaseGitPackage await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); await venvRunner - .PipInstall(PyVenvRunner.TorchPipInstallArgsRocm511, onConsoleOutput) + .PipInstall( + new PipInstallArgs() + .WithTorch("==2.0.1") + .WithTorchVision() + .WithTorchExtraIndex("rocm5.1.1"), + onConsoleOutput + ) .ConfigureAwait(false); } } diff --git a/StabilityMatrix.Core/Processes/Argument.cs b/StabilityMatrix.Core/Processes/Argument.cs new file mode 100644 index 00000000..44f2c846 --- /dev/null +++ b/StabilityMatrix.Core/Processes/Argument.cs @@ -0,0 +1,6 @@ +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +[GenerateOneOf] +public partial class Argument : OneOfBase { } diff --git a/StabilityMatrix.Core/Processes/ProcessArgs.cs b/StabilityMatrix.Core/Processes/ProcessArgs.cs new file mode 100644 index 00000000..477b5a54 --- /dev/null +++ b/StabilityMatrix.Core/Processes/ProcessArgs.cs @@ -0,0 +1,72 @@ +using System.Collections; +using System.Text.RegularExpressions; +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +/// +/// Parameter type for command line arguments +/// Implicitly converts between string and string[], +/// with no parsing if the input and output types are the same. +/// +public partial class ProcessArgs : OneOfBase, IEnumerable +{ + /// + public ProcessArgs(OneOf input) + : base(input) { } + + /// + /// Whether the argument string contains the given substring, + /// or any of the given arguments if the input is an array. + /// + public bool Contains(string arg) => Match(str => str.Contains(arg), arr => arr.Any(Contains)); + + public ProcessArgs Concat(ProcessArgs other) => + Match( + str => new ProcessArgs(string.Join(' ', str, other.ToString())), + arr => new ProcessArgs(arr.Concat(other.ToArray()).ToArray()) + ); + + public ProcessArgs Prepend(ProcessArgs other) => + Match( + str => new ProcessArgs(string.Join(' ', other.ToString(), str)), + arr => new ProcessArgs(other.ToArray().Concat(arr).ToArray()) + ); + + /// + public IEnumerator GetEnumerator() + { + return ToArray().AsEnumerable().GetEnumerator(); + } + + /// + public override string ToString() + { + return Match(str => str, arr => string.Join(' ', arr.Select(ProcessRunner.Quote))); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public string[] ToArray() => + Match( + str => ArgumentsRegex().Matches(str).Select(x => x.Value.Trim('"')).ToArray(), + arr => arr + ); + + // Implicit conversions + + public static implicit operator ProcessArgs(string input) => new(input); + + public static implicit operator ProcessArgs(string[] input) => new(input); + + public static implicit operator string(ProcessArgs input) => input.ToString(); + + public static implicit operator string[](ProcessArgs input) => input.ToArray(); + + [GeneratedRegex("""[\"].+?[\"]|[^ ]+""", RegexOptions.IgnoreCase)] + private static partial Regex ArgumentsRegex(); +} diff --git a/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs new file mode 100644 index 00000000..89649564 --- /dev/null +++ b/StabilityMatrix.Core/Processes/ProcessArgsBuilder.cs @@ -0,0 +1,75 @@ +using System.Diagnostics; +using OneOf; + +namespace StabilityMatrix.Core.Processes; + +/// +/// Builder for . +/// +public record ProcessArgsBuilder +{ + protected ProcessArgsBuilder() { } + + public ProcessArgsBuilder(params Argument[] arguments) + { + Arguments = arguments.ToList(); + } + + public List Arguments { get; init; } = new(); + + private IEnumerable ToStringArgs() + { + foreach (var argument in Arguments) + { + if (argument.IsT0) + { + yield return argument.AsT0; + } + else + { + yield return argument.AsT1.Item1; + yield return argument.AsT1.Item2; + } + } + } + + /// + public override string ToString() + { + return ToProcessArgs().ToString(); + } + + public ProcessArgs ToProcessArgs() + { + return ToStringArgs().ToArray(); + } + + public static implicit operator ProcessArgs(ProcessArgsBuilder builder) => + builder.ToProcessArgs(); +} + +public static class ProcessArgBuilderExtensions +{ + public static T AddArg(this T builder, Argument argument) + where T : ProcessArgsBuilder + { + return builder with { Arguments = builder.Arguments.Append(argument).ToList() }; + } + + public static T RemoveArgKey(this T builder, string argumentKey) + where T : ProcessArgsBuilder + { + return builder with + { + Arguments = builder.Arguments + .Where( + x => + x.Match( + stringArg => stringArg != argumentKey, + tupleArg => tupleArg.Item1 != argumentKey + ) + ) + .ToList() + }; + } +} diff --git a/StabilityMatrix.Core/Python/PipInstallArgs.cs b/StabilityMatrix.Core/Python/PipInstallArgs.cs new file mode 100644 index 00000000..d16aedf7 --- /dev/null +++ b/StabilityMatrix.Core/Python/PipInstallArgs.cs @@ -0,0 +1,31 @@ +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Core.Python; + +public record PipInstallArgs : ProcessArgsBuilder +{ + public PipInstallArgs(params Argument[] arguments) + : base(arguments) { } + + public PipInstallArgs WithTorch(string version = "") => this.AddArg($"torch{version}"); + + public PipInstallArgs WithTorchDirectML(string version = "") => + this.AddArg($"torch-directml{version}"); + + public PipInstallArgs WithTorchVision(string version = "") => + this.AddArg($"torchvision{version}"); + + public PipInstallArgs WithXFormers(string version = "") => this.AddArg($"xformers{version}"); + + public PipInstallArgs WithExtraIndex(string indexUrl) => + this.AddArg(("--extra-index-url", indexUrl)); + + public PipInstallArgs WithTorchExtraIndex(string index) => + this.AddArg(("--extra-index-url", $"https://download.pytorch.org/whl/{index}")); + + /// + public override string ToString() + { + return base.ToString(); + } +} diff --git a/StabilityMatrix.Core/Python/PyVenvRunner.cs b/StabilityMatrix.Core/Python/PyVenvRunner.cs index 6bbabf14..9b288ad5 100644 --- a/StabilityMatrix.Core/Python/PyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/PyVenvRunner.cs @@ -19,25 +19,6 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - private const string TorchPipInstallArgs = "torch==2.0.1 torchvision"; - - public const string TorchPipInstallArgsCuda = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/cu118"; - public const string TorchPipInstallArgsCuda121 = - "torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121"; - public const string TorchPipInstallArgsCpu = TorchPipInstallArgs; - public const string TorchPipInstallArgsDirectML = "torch-directml"; - - public const string TorchPipInstallArgsRocm511 = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.1.1"; - public const string TorchPipInstallArgsRocm542 = - $"{TorchPipInstallArgs} --extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; - public const string TorchPipInstallArgsRocm56 = - $"{TorchPipInstallArgs} --index-url https://download.pytorch.org/whl/rocm5.6"; - - public const string TorchPipInstallArgsNightlyCpu = - "--pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu"; - /// /// Relative path to the site-packages folder from the venv root. /// This is platform specific. @@ -216,7 +197,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable /// Run a pip install command. Waits for the process to exit. /// workingDirectory defaults to RootPath. /// - public async Task PipInstall(string args, Action? outputDataReceived = null) + public async Task PipInstall(ProcessArgs args, Action? outputDataReceived = null) { if (!File.Exists(PipPath)) { @@ -236,7 +217,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable }); SetPyvenvCfg(PyRunner.PythonDir); - RunDetached($"-m pip install {args}", outputAction); + RunDetached(args.Prepend("-m pip install"), outputAction); await Process.WaitForExitAsync().ConfigureAwait(false); // Check return code @@ -349,7 +330,7 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable /// Run a command using the venv Python executable and return the result. /// /// Arguments to pass to the Python executable. - public async Task Run(string arguments) + public async Task Run(ProcessArgs arguments) { // Record output for errors var output = new StringBuilder(); @@ -381,12 +362,14 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable [MemberNotNull(nameof(Process))] public void RunDetached( - string arguments, + ProcessArgs args, Action? outputDataReceived, Action? onExit = null, bool unbuffered = true ) { + var arguments = args.ToString(); + if (!PythonPath.Exists) { throw new FileNotFoundException("Venv python not found", PythonPath); @@ -395,10 +378,10 @@ public class PyVenvRunner : IDisposable, IAsyncDisposable Logger.Info( "Launching venv process [{PythonPath}] " - + "in working directory [{WorkingDirectory}] with args {arguments.ToRepr()}", + + "in working directory [{WorkingDirectory}] with args {Arguments}", PythonPath, WorkingDirectory, - arguments.ToRepr() + arguments ); var filteredOutput = diff --git a/StabilityMatrix.Core/StabilityMatrix.Core.csproj b/StabilityMatrix.Core/StabilityMatrix.Core.csproj index 1a78f975..70900a96 100644 --- a/StabilityMatrix.Core/StabilityMatrix.Core.csproj +++ b/StabilityMatrix.Core/StabilityMatrix.Core.csproj @@ -32,6 +32,8 @@ + + diff --git a/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs new file mode 100644 index 00000000..9bb09eab --- /dev/null +++ b/StabilityMatrix.Tests/Core/PipInstallArgsTests.cs @@ -0,0 +1,61 @@ +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class PipInstallArgsTests +{ + [TestMethod] + public void TestGetTorch() + { + // Arrange + const string version = "==2.1.0"; + + // Act + var args = new PipInstallArgs().WithTorch(version).ToProcessArgs().ToString(); + + // Assert + Assert.AreEqual("torch==2.1.0", args); + } + + [TestMethod] + public void TestGetTorchWithExtraIndex() + { + // Arrange + const string version = ">=2.0.0"; + const string index = "cu118"; + + // Act + var args = new PipInstallArgs() + .WithTorch(version) + .WithTorchVision() + .WithTorchExtraIndex(index) + .ToProcessArgs() + .ToString(); + + // Assert + Assert.AreEqual( + "torch>=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu118", + args + ); + } + + [TestMethod] + public void TestGetTorchWithMoreStuff() + { + // Act + var args = new PipInstallArgs() + .AddArg("--pre") + .WithTorch("~=2.0.0") + .WithTorchVision() + .WithTorchExtraIndex("nightly/cpu") + .ToString(); + + // Assert + Assert.AreEqual( + "--pre torch~=2.0.0 torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu", + args + ); + } +} diff --git a/StabilityMatrix.Tests/Models/ProcessArgsTests.cs b/StabilityMatrix.Tests/Models/ProcessArgsTests.cs new file mode 100644 index 00000000..ae3281d9 --- /dev/null +++ b/StabilityMatrix.Tests/Models/ProcessArgsTests.cs @@ -0,0 +1,43 @@ +using StabilityMatrix.Core.Processes; + +namespace StabilityMatrix.Tests.Models; + +[TestClass] +public class ProcessArgsTests +{ + [DataTestMethod] + [DataRow("pip", new[] { "pip" })] + [DataRow("pip install torch", new[] { "pip", "install", "torch" })] + [DataRow( + "pip install -r \"file spaces/here\"", + new[] { "pip", "install", "-r", "file spaces/here" } + )] + [DataRow( + "pip install -r \"file spaces\\here\"", + new[] { "pip", "install", "-r", "file spaces\\here" } + )] + public void TestStringToArray(string input, string[] expected) + { + ProcessArgs args = input; + string[] result = args; + CollectionAssert.AreEqual(expected, result); + } + + [DataTestMethod] + [DataRow(new[] { "pip" }, "pip")] + [DataRow(new[] { "pip", "install", "torch" }, "pip install torch")] + [DataRow( + new[] { "pip", "install", "-r", "file spaces/here" }, + "pip install -r \"file spaces/here\"" + )] + [DataRow( + new[] { "pip", "install", "-r", "file spaces\\here" }, + "pip install -r \"file spaces\\here\"" + )] + public void TestArrayToString(string[] input, string expected) + { + ProcessArgs args = input; + string result = args; + Assert.AreEqual(expected, result); + } +}