Browse Source

Add save load support for files with GenerationParameters

pull/165/head
Ionite 1 year ago
parent
commit
e2e85c6f57
No known key found for this signature in database
  1. 15
      StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs
  2. 76
      StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs
  3. 28
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  4. 49
      StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs
  5. 20
      StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs
  6. 55
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

15
StabilityMatrix.Avalonia/Models/IParametersLoadableState.cs

@ -0,0 +1,15 @@
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.Models;
public interface IParametersLoadableState
{
void LoadStateFromParameters(GenerationParameters parameters);
GenerationParameters SaveStateToParameters(GenerationParameters parameters);
public GenerationParameters SaveStateToParameters()
{
return SaveStateToParameters(new GenerationParameters());
}
}

76
StabilityMatrix.Avalonia/ViewModels/Base/InferenceTabViewModelBase.cs

@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
@ -6,6 +7,7 @@ using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Input;
using Avalonia.Platform.Storage;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
@ -13,6 +15,8 @@ using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.ViewModels.Inference;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
@ -144,6 +148,66 @@ public abstract partial class InferenceTabViewModelBase
GC.SuppressFinalize(this);
}
private bool TryLoadImageMetadata(FilePath? filePath)
{
if (filePath is not { Exists: true })
return false;
var metadata = ImageMetadata.GetAllFileMetadata(filePath);
// Has SMProject metadata
if (metadata.SMProject is not null)
{
var project = JsonSerializer.Deserialize<InferenceProjectDocument>(metadata.SMProject);
// Check project type matches
if (project?.ProjectType.ToViewModelType() == GetType() && project.State is not null)
{
LoadStateFromJsonObject(project.State);
}
else
{
return false;
}
// Load image
if (this is IImageGalleryComponent imageGalleryComponent)
{
imageGalleryComponent.LoadImagesToGallery(new ImageSource(filePath));
}
return true;
}
// Has generic metadata
if (metadata.Parameters is { } parametersString)
{
if (!GenerationParameters.TryParse(parametersString, out var parameters))
{
return false;
}
if (this is IParametersLoadableState paramsLoadableVm)
{
paramsLoadableVm.LoadStateFromParameters(parameters);
}
else
{
return false;
}
// Load image
if (this is IImageGalleryComponent imageGalleryComponent)
{
imageGalleryComponent.LoadImagesToGallery(new ImageSource(filePath));
}
return true;
}
return false;
}
/// <inheritdoc />
public void DragOver(object? sender, DragEventArgs e)
{
@ -162,10 +226,10 @@ public abstract partial class InferenceTabViewModelBase
if (e.Data.GetDataFormats().Contains(DataFormats.Files))
{
e.Handled = true;
e.DragEffects = DragDropEffects.None;
return;
}
// Other kinds - not supported
e.DragEffects = DragDropEffects.None;
}
@ -214,6 +278,16 @@ public abstract partial class InferenceTabViewModelBase
if (e.Data.GetDataFormats().Contains(DataFormats.Files))
{
e.Handled = true;
if (e.Data.Get(DataFormats.Files) is IEnumerable<IStorageItem> files)
{
var paths = files.Select(f => f.TryGetLocalPath()).ToList();
if (paths.FirstOrDefault() is { } file)
{
Dispatcher.UIThread.Post(() => TryLoadImageMetadata(file));
}
}
}
}
}

28
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -25,7 +25,9 @@ using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.Infere
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(InferenceTextToImageView), persistent: true)]
public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase
public class InferenceTextToImageViewModel
: InferenceGenerationViewModelBase,
IParametersLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@ -327,4 +329,28 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase
await RunGeneration(generationArgs, cancellationToken);
}
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
PromptCardViewModel.LoadStateFromParameters(parameters);
SamplerCardViewModel.LoadStateFromParameters(parameters);
SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);
ModelCardViewModel.LoadStateFromParameters(parameters);
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
parameters = PromptCardViewModel.SaveStateToParameters(parameters);
parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
parameters.Seed = (ulong)SeedCardViewModel.Seed;
parameters = ModelCardViewModel.SaveStateToParameters(parameters);
return parameters;
}
}

49
StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs

@ -1,7 +1,9 @@
using System.Linq;
using System;
using System.Linq;
using System.Text.Json.Nodes;
using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
@ -10,7 +12,7 @@ using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(ModelCard))]
public partial class ModelCardViewModel : LoadableViewModelBase
public partial class ModelCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
[ObservableProperty]
private HybridModelFile? selectedModel;
@ -71,4 +73,47 @@ public partial class ModelCardViewModel : LoadableViewModelBase
public string? SelectedVaeName { get; init; }
public bool IsVaeSelectionEnabled { get; init; }
}
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
if (parameters.ModelName is not { } paramsModelName)
return;
var currentModels = ClientManager.Models;
HybridModelFile? model;
// First try hash match
if (parameters.ModelHash is not null)
{
model = currentModels.FirstOrDefault(
m =>
m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256
&& sha256.StartsWith(
parameters.ModelHash,
StringComparison.InvariantCultureIgnoreCase
)
);
}
else
{
// Name matches
model = currentModels.FirstOrDefault(m => m.FileName.EndsWith(paramsModelName));
model ??= currentModels.FirstOrDefault(
m => m.ShortDisplayName.StartsWith(paramsModelName)
);
}
if (model is not null)
{
SelectedModel = model;
}
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return parameters with { ModelName = SelectedModel?.FileName };
}
}

20
StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs

@ -22,12 +22,13 @@ using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Exceptions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper.Cache;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(PromptCard))]
public partial class PromptCardViewModel : LoadableViewModelBase
public partial class PromptCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
private readonly IModelIndexService modelIndexService;
@ -284,4 +285,21 @@ public partial class PromptCardViewModel : LoadableViewModelBase
PromptDocument.Text = model.Prompt ?? "";
NegativePromptDocument.Text = model.NegativePrompt ?? "";
}
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
PromptDocument.Text = parameters.PositivePrompt ?? "";
NegativePromptDocument.Text = parameters.NegativePrompt ?? "";
}
/// <inheritdoc />
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return parameters with
{
PositivePrompt = PromptDocument.Text,
NegativePrompt = NegativePromptDocument.Text
};
}
}

55
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -1,15 +1,18 @@
using System.Text.Json.Serialization;
using System.Linq;
using System.Text.Json.Serialization;
using CommunityToolkit.Mvvm.ComponentModel;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(SamplerCard))]
public partial class SamplerCardViewModel : LoadableViewModelBase
public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLoadableState
{
[ObservableProperty]
private bool isRefinerStepsEnabled;
@ -61,42 +64,32 @@ public partial class SamplerCardViewModel : LoadableViewModelBase
ClientManager = clientManager;
}
/*/// <inheritdoc />
public override void LoadStateFromJsonObject(JsonObject state)
/// <inheritdoc />
public void LoadStateFromParameters(GenerationParameters parameters)
{
Width = parameters.Width;
Height = parameters.Height;
Steps = parameters.Steps;
CfgScale = parameters.CfgScale;
if (parameters.GetComfySamplers() is { } paramSamplers)
{
var model = DeserializeModel<SamplerCardModel>(state);
Steps = model.Steps;
IsDenoiseStrengthEnabled = model.IsDenoiseStrengthEnabled;
DenoiseStrength = model.DenoiseStrength;
IsCfgScaleEnabled = model.IsCfgScaleEnabled;
CfgScale = model.CfgScale;
IsDimensionsEnabled = model.IsDimensionsEnabled;
Width = model.Width;
Height = model.Height;
IsSamplerSelectionEnabled = model.IsSamplerSelectionEnabled;
SelectedSampler = model.SelectedSampler is null
? null
: new ComfySampler(model.SelectedSampler);
var (sampler, scheduler) = paramSamplers;
SelectedSampler = ClientManager.Samplers.FirstOrDefault(s => s.Name == sampler.Name);
}
}
/// <inheritdoc />
public override JsonObject SaveStateToJsonObject()
public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
{
return SerializeModel(
new SamplerCardModel
return parameters with
{
Steps = Steps,
IsDenoiseStrengthEnabled = IsDenoiseStrengthEnabled,
DenoiseStrength = DenoiseStrength,
IsCfgScaleEnabled = IsCfgScaleEnabled,
CfgScale = CfgScale,
IsDimensionsEnabled = IsDimensionsEnabled,
Width = Width,
Height = Height,
IsSamplerSelectionEnabled = IsSamplerSelectionEnabled,
SelectedSampler = SelectedSampler?.Name
Steps = Steps,
CfgScale = CfgScale,
Sampler = SelectedSampler?.Name
};
}
);
}*/
}

Loading…
Cancel
Save