Browse Source

Add save load support for files with GenerationParameters

pull/165/head
Ionite 1 year ago
parent
commit
e2e85c6f57
No known key found for this signature in database
  1. 15
      StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs
  2. 76
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs
  3. 28
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  4. 49
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  5. 20
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  6. 63
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

15
StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs

@ -0,0 +1,15 @@
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Models;
public interface IParametersLoadableState
{
void LoadStateFromParameters(GenerationParameters parameters);
GenerationParameters SaveStateToParameters(GenerationParameters parameters);
public GenerationParameters SaveStateToParameters()
{
return SaveStateToParameters(new GenerationParameters());
}
}

76
StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs

@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
@ -6,6 +7,7 @@ using System.Threading.Tasks;
using AsyncAwaitBestPractices; using AsyncAwaitBestPractices;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Input; using Avalonia.Input;
using Avalonia.Platform.Storage;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
@ -13,6 +15,8 @@ using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Models; using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference; using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database; using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.FileInterfaces;
@ -144,6 +148,66 @@ public abstract partial class InferenceTabViewModelBase
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
private bool TryLoadImageMetadata(FilePath? filePath)
{
if (filePath is not { Exists: true })
return false;
var metadata = ImageMetadata.GetAllFileMetadata(filePath);
// Has SMProject metadata
if (metadata.SMProject is not null)
{
var project = JsonSerializer.Deserialize<InferenceProjectDocument>(metadata.SMProject);
// Check project type matches
if (project?.ProjectType.ToViewModelType() == GetType() && project.State is not null)
{
LoadStateFromJsonObject(project.State);
}
else
{
return false;
}
// Load image
if (this is IImageGalleryComponent imageGalleryComponent)
{
imageGalleryComponent.LoadImagesToGallery(new ImageSource(filePath));
}
return true;
}
// Has generic metadata
if (metadata.Parameters is { } parametersString)
{
if (!GenerationParameters.TryParse(parametersString, out var parameters))
{
return false;
}
if (this is IParametersLoadableState paramsLoadableVm)
{
paramsLoadableVm.LoadStateFromParameters(parameters);
}
else
{
return false;
}
// Load image
if (this is IImageGalleryComponent imageGalleryComponent)
{
imageGalleryComponent.LoadImagesToGallery(new ImageSource(filePath));
}
return true;
}
return false;
}
/// <inheritdoc /> /// <inheritdoc />
public void DragOver(object? sender, DragEventArgs e) public void DragOver(object? sender, DragEventArgs e)
{ {
@ -162,10 +226,10 @@ public abstract partial class InferenceTabViewModelBase
if (e.Data.GetDataFormats().Contains(DataFormats.Files)) if (e.Data.GetDataFormats().Contains(DataFormats.Files))
{ {
e.Handled = true; e.Handled = true;
e.DragEffects = DragDropEffects.None;
return; return;
} }
// Other kinds - not supported
e.DragEffects = DragDropEffects.None; e.DragEffects = DragDropEffects.None;
} }
@ -214,6 +278,16 @@ public abstract partial class InferenceTabViewModelBase
if (e.Data.GetDataFormats().Contains(DataFormats.Files)) if (e.Data.GetDataFormats().Contains(DataFormats.Files))
{ {
e.Handled = true; e.Handled = true;
if (e.Data.Get(DataFormats.Files) is IEnumerable<IStorageItem> files)
{
var paths = files.Select(f => f.TryGetLocalPath()).ToList();
if (paths.FirstOrDefault() is { } file)
{
Dispatcher.UIThread.Post(() => TryLoadImageMetadata(file));
}
}
} }
} }
} }

28
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -25,7 +25,9 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), persistent: true)] [View(typeof(InferenceTextToImageView), persistent: true)]
public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase public class InferenceTextToImageViewModel
: InferenceGenerationViewModelBase,
IParametersLoadableState
{ {
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -327,4 +329,28 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase
await RunGeneration(generationArgs, cancellationToken); await RunGeneration(generationArgs, cancellationToken);
} }
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
PromptCardViewModel.LoadStateFromParameters(parameters);
SamplerCardViewModel.LoadStateFromParameters(parameters);
SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);
ModelCardViewModel.LoadStateFromParameters(parameters);
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
parameters = PromptCardViewModel.SaveStateToParameters(parameters);
parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
parameters.Seed = (ulong)SeedCardViewModel.Seed;
parameters = ModelCardViewModel.SaveStateToParameters(parameters);
return parameters;
}
} }

49
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -1,7 +1,9 @@
using System.Linq; using System;
using System.Linq;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
@ -10,7 +12,7 @@ using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ModelCard))] [View(typeof(ModelCard))]
public partial class ModelCardViewModel : LoadableViewModelBase public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState
{ {
[ObservableProperty] [ObservableProperty]
private HybridModelFile? selectedModel; private HybridModelFile? selectedModel;
@ -71,4 +73,47 @@ public partial class ModelCardViewModel : LoadableViewModelBase
public string? SelectedVaeName { get; init; } public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; } public bool IsVaeSelectionEnabled { get; init; }
} }
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
if (parameters.ModelName is not { } paramsModelName)
return;
var currentModels = ClientManager.Models;
HybridModelFile? model;
// First try hash match
if (parameters.ModelHash is not null)
{
model = currentModels.FirstOrDefault(
m =>
m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
&& sha256.StartsWith(
parameters.ModelHash,
StringComparison.InvariantCultureIgnoreCase
)
);
}
else
{
// Name matches
model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName));
model ??= currentModels.FirstOrDefault(
m => m.ShortDisplayName.StartsWith(paramsModelName)
);
}
if (model is not null)
{
SelectedModel = model;
}
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return parameters with { ModelName = SelectedModel?.FileName };
}
} }

20
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -22,12 +22,13 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))] [View(typeof(PromptCard))]
public partial class PromptCardViewModel : LoadableViewModelBase public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState
{ {
private readonly IModelIndexService modelIndexService; private readonly IModelIndexService modelIndexService;
@ -284,4 +285,21 @@ public partial class PromptCardViewModel : LoadableViewModelBase
PromptDocument.Text = model.Prompt ?? ""; PromptDocument.Text = model.Prompt ?? "";
NegativePromptDocument.Text = model.NegativePrompt ?? ""; NegativePromptDocument.Text = model.NegativePrompt ?? "";
} }
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
PromptDocument.Text = parameters.PositivePrompt ?? "";
NegativePromptDocument.Text = parameters.NegativePrompt ?? "";
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return parameters with
{
PositivePrompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text
};
}
} }

63
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -1,15 +1,18 @@
using System.Text.Json.Serialization; using System.Linq;
using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy; using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Avalonia.ViewModels.Inference; namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SamplerCard))] [View(typeof(SamplerCard))]
public partial class SamplerCardViewModel : LoadableViewModelBase public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState
{ {
[ObservableProperty] [ObservableProperty]
private bool isRefinerStepsEnabled; private bool isRefinerStepsEnabled;
@ -61,42 +64,32 @@ public partial class SamplerCardViewModel : LoadableViewModelBase
ClientManager = clientManager; ClientManager = clientManager;
} }
/*/// <inheritdoc /> /// <inheritdoc />
public override void LoadStateFromJsonObject(JsonObject state) public void LoadStateFromParameters(GenerationParameters parameters)
{ {
var model = DeserializeModel<SamplerCardModel>(state); Width = parameters.Width;
Height = parameters.Height;
Steps = model.Steps; Steps = parameters.Steps;
IsDenoiseStrengthEnabled = model.IsDenoiseStrengthEnabled; CfgScale = parameters.CfgScale;
DenoiseStrength = model.DenoiseStrength;
IsCfgScaleEnabled = model.IsCfgScaleEnabled; if (parameters.GetComfySamplers() is { } paramSamplers)
CfgScale = model.CfgScale; {
IsDimensionsEnabled = model.IsDimensionsEnabled; var (sampler, scheduler) = paramSamplers;
Width = model.Width;
Height = model.Height; SelectedSampler = ClientManager.Samplers.FirstOrDefault(s => s.Name == sampler.Name);
IsSamplerSelectionEnabled = model.IsSamplerSelectionEnabled; }
SelectedSampler = model.SelectedSampler is null
? null
: new ComfySampler(model.SelectedSampler);
} }
/// <inheritdoc /> /// <inheritdoc />
public override JsonObject SaveStateToJsonObject() public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{ {
return SerializeModel( return parameters with
new SamplerCardModel {
{ Width = Width,
Steps = Steps, Height = Height,
IsDenoiseStrengthEnabled = IsDenoiseStrengthEnabled, Steps = Steps,
DenoiseStrength = DenoiseStrength, CfgScale = CfgScale,
IsCfgScaleEnabled = IsCfgScaleEnabled, Sampler = SelectedSampler?.Name
CfgScale = CfgScale, };
IsDimensionsEnabled = IsDimensionsEnabled, }
Width = Width,
Height = Height,
IsSamplerSelectionEnabled = IsSamplerSelectionEnabled,
SelectedSampler = SelectedSampler?.Name
}
);
}*/
} }

Loading…
Cancel
Save