From e2e85c6f579f70f0911a187908406751938ec88e Mon Sep 17 00:00:00 2001 From: Ionite Date: Thu, 28 Sep 2023 16:20:30 -0400 Subject: [PATCH] Add save load support for files with GenerationParameters --- .../Models/IParametersLoadableState.cs | 15 ++++ .../Base/InferenceTabViewModelBase.cs | 76 ++++++++++++++++++- .../InferenceTextToImageViewModel.cs | 28 ++++++- .../Inference/ModelCardViewModel.cs | 49 +++++++++++- .../Inference/PromptCardViewModel.cs | 20 ++++- .../Inference/SamplerCardViewModel.cs | 63 +++++++-------- 6 files changed, 211 insertions(+), 40 deletions(-) create mode 100644 StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs diff --git a/StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs b/StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs new file mode 100644 index 00000000..c06d88d2 --- /dev/null +++ b/StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs @@ -0,0 +1,15 @@ +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Avalonia.Models; + +public interface IParametersLoadableState +{ + void LoadStateFromParameters(GenerationParameters parameters); + + GenerationParameters SaveStateToParameters(GenerationParameters parameters); + + public GenerationParameters SaveStateToParameters() + { + return SaveStateToParameters(new GenerationParameters()); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs index 1ed26e39..b1586022 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; @@ -6,6 +7,7 @@ using System.Threading.Tasks; using AsyncAwaitBestPractices; using Avalonia.Controls; using Avalonia.Input; +using Avalonia.Platform.Storage; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; @@ -13,6 +15,8 @@ using FluentAvalonia.UI.Controls; using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.FileInterfaces; @@ -144,6 +148,66 @@ public abstract partial class InferenceTabViewModelBase GC.SuppressFinalize(this); } + private bool TryLoadImageMetadata(FilePath? filePath) + { + if (filePath is not { Exists: true }) + return false; + + var metadata = ImageMetadata.GetAllFileMetadata(filePath); + + // Has SMProject metadata + if (metadata.SMProject is not null) + { + var project = JsonSerializer.Deserialize(metadata.SMProject); + + // Check project type matches + if (project?.ProjectType.ToViewModelType() == GetType() && project.State is not null) + { + LoadStateFromJsonObject(project.State); + } + else + { + return false; + } + + // Load image + if (this is IImageGalleryComponent imageGalleryComponent) + { + imageGalleryComponent.LoadImagesToGallery(new ImageSource(filePath)); + } + + return true; + } + + // Has generic metadata + if (metadata.Parameters is { } parametersString) + { + if (!GenerationParameters.TryParse(parametersString, out var parameters)) + { + return false; + } + + if (this is IParametersLoadableState paramsLoadableVm) + { + paramsLoadableVm.LoadStateFromParameters(parameters); + } + else + { + return false; + } + + // Load image + if (this is IImageGalleryComponent imageGalleryComponent) + { + imageGalleryComponent.LoadImagesToGallery(new ImageSource(filePath)); + } + + return true; + } + + return false; + } + /// public void DragOver(object? sender, DragEventArgs e) { @@ -162,10 +226,10 @@ public abstract partial class InferenceTabViewModelBase if (e.Data.GetDataFormats().Contains(DataFormats.Files)) { e.Handled = true; - e.DragEffects = DragDropEffects.None; return; } + // Other kinds - not supported e.DragEffects = DragDropEffects.None; } @@ -214,6 +278,16 @@ public abstract partial class InferenceTabViewModelBase if (e.Data.GetDataFormats().Contains(DataFormats.Files)) { e.Handled = true; + + if (e.Data.Get(DataFormats.Files) is IEnumerable files) + { + var paths = files.Select(f => f.TryGetLocalPath()).ToList(); + + if (paths.FirstOrDefault() is { } file) + { + Dispatcher.UIThread.Post(() => TryLoadImageMetadata(file)); + } + } } } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 249aa449..2606a1f4 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -25,7 +25,9 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(InferenceTextToImageView), persistent: true)] -public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase +public class InferenceTextToImageViewModel + : InferenceGenerationViewModelBase, + IParametersLoadableState { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -327,4 +329,28 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase await RunGeneration(generationArgs, cancellationToken); } + + /// + public void LoadStateFromParameters(GenerationParameters parameters) + { + PromptCardViewModel.LoadStateFromParameters(parameters); + SamplerCardViewModel.LoadStateFromParameters(parameters); + + SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed); + + ModelCardViewModel.LoadStateFromParameters(parameters); + } + + /// + public GenerationParameters SaveStateToParameters(GenerationParameters parameters) + { + parameters = PromptCardViewModel.SaveStateToParameters(parameters); + parameters = SamplerCardViewModel.SaveStateToParameters(parameters); + + parameters.Seed = (ulong)SeedCardViewModel.Seed; + + parameters = ModelCardViewModel.SaveStateToParameters(parameters); + + return parameters; + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs index d46c999b..2c28da8f 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs @@ -1,7 +1,9 @@ -using System.Linq; +using System; +using System.Linq; using System.Text.Json.Nodes; using CommunityToolkit.Mvvm.ComponentModel; using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; @@ -10,7 +12,7 @@ using StabilityMatrix.Core.Models; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(ModelCard))] -public partial class ModelCardViewModel : LoadableViewModelBase +public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState { [ObservableProperty] private HybridModelFile? selectedModel; @@ -71,4 +73,47 @@ public partial class ModelCardViewModel : LoadableViewModelBase public string? SelectedVaeName { get; init; } public bool IsVaeSelectionEnabled { get; init; } } + + /// + public void LoadStateFromParameters(GenerationParameters parameters) + { + if (parameters.ModelName is not { } paramsModelName) + return; + + var currentModels = ClientManager.Models; + + HybridModelFile? model; + + // First try hash match + if (parameters.ModelHash is not null) + { + model = currentModels.FirstOrDefault( + m => + m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256 + && sha256.StartsWith( + parameters.ModelHash, + StringComparison.InvariantCultureIgnoreCase + ) + ); + } + else + { + // Name matches + model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName)); + model ??= currentModels.FirstOrDefault( + m => m.ShortDisplayName.StartsWith(paramsModelName) + ); + } + + if (model is not null) + { + SelectedModel = model; + } + } + + /// + public GenerationParameters SaveStateToParameters(GenerationParameters parameters) + { + return parameters with { ModelName = SelectedModel?.FileName }; + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs index f4f65f91..a7b76ad0 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs @@ -22,12 +22,13 @@ using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(PromptCard))] -public partial class PromptCardViewModel : LoadableViewModelBase +public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState { private readonly IModelIndexService modelIndexService; @@ -284,4 +285,21 @@ public partial class PromptCardViewModel : LoadableViewModelBase PromptDocument.Text = model.Prompt ?? ""; NegativePromptDocument.Text = model.NegativePrompt ?? ""; } + + /// + public void LoadStateFromParameters(GenerationParameters parameters) + { + PromptDocument.Text = parameters.PositivePrompt ?? ""; + NegativePromptDocument.Text = parameters.NegativePrompt ?? ""; + } + + /// + public GenerationParameters SaveStateToParameters(GenerationParameters parameters) + { + return parameters with + { + PositivePrompt = PromptDocument.Text, + NegativePrompt = NegativePromptDocument.Text + }; + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index 17448c13..1200e013 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -1,15 +1,18 @@ -using System.Text.Json.Serialization; +using System.Linq; +using System.Text.Json.Serialization; using CommunityToolkit.Mvvm.ComponentModel; using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; namespace StabilityMatrix.Avalonia.ViewModels.Inference; [View(typeof(SamplerCard))] -public partial class SamplerCardViewModel : LoadableViewModelBase +public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState { [ObservableProperty] private bool isRefinerStepsEnabled; @@ -61,42 +64,32 @@ public partial class SamplerCardViewModel : LoadableViewModelBase ClientManager = clientManager; } - /*/// - public override void LoadStateFromJsonObject(JsonObject state) + /// + public void LoadStateFromParameters(GenerationParameters parameters) { - var model = DeserializeModel(state); - - Steps = model.Steps; - IsDenoiseStrengthEnabled = model.IsDenoiseStrengthEnabled; - DenoiseStrength = model.DenoiseStrength; - IsCfgScaleEnabled = model.IsCfgScaleEnabled; - CfgScale = model.CfgScale; - IsDimensionsEnabled = model.IsDimensionsEnabled; - Width = model.Width; - Height = model.Height; - IsSamplerSelectionEnabled = model.IsSamplerSelectionEnabled; - SelectedSampler = model.SelectedSampler is null - ? null - : new ComfySampler(model.SelectedSampler); + Width = parameters.Width; + Height = parameters.Height; + Steps = parameters.Steps; + CfgScale = parameters.CfgScale; + + if (parameters.GetComfySamplers() is { } paramSamplers) + { + var (sampler, scheduler) = paramSamplers; + + SelectedSampler = ClientManager.Samplers.FirstOrDefault(s => s.Name == sampler.Name); + } } /// - public override JsonObject SaveStateToJsonObject() + public GenerationParameters SaveStateToParameters(GenerationParameters parameters) { - return SerializeModel( - new SamplerCardModel - { - Steps = Steps, - IsDenoiseStrengthEnabled = IsDenoiseStrengthEnabled, - DenoiseStrength = DenoiseStrength, - IsCfgScaleEnabled = IsCfgScaleEnabled, - CfgScale = CfgScale, - IsDimensionsEnabled = IsDimensionsEnabled, - Width = Width, - Height = Height, - IsSamplerSelectionEnabled = IsSamplerSelectionEnabled, - SelectedSampler = SelectedSampler?.Name - } - ); - }*/ + return parameters with + { + Width = Width, + Height = Height, + Steps = Steps, + CfgScale = CfgScale, + Sampler = SelectedSampler?.Name + }; + } }