Multi-Platform Package Manager for Stable Diffusion
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

178 lines
6.1 KiB

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using StabilityMatrix.Core.Services;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
/// <summary>
/// Inference view model that takes a user-selected input image and optionally runs it
/// through an upscale group and/or an image-sharpen node before setting up the output image.
/// </summary>
[View(typeof(InferenceImageUpscaleView), persistent: true)]
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)]
[ManagedService]
[Transient]
public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;

    /// <summary>
    /// Stack that hosts the "Upscale" and "Sharpen" expanders (added in that order by the constructor).
    /// Excluded from JSON serialization.
    /// </summary>
    [JsonIgnore]
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Upscaler settings card; serialized under the "Upscaler" key.</summary>
    [JsonPropertyName("Upscaler")]
    public UpscalerCardViewModel UpscalerCardViewModel { get; }

    /// <summary>Sharpen settings card; serialized under the "Sharpen" key.</summary>
    [JsonPropertyName("Sharpen")]
    public SharpenCardViewModel SharpenCardViewModel { get; }

    /// <summary>Input image selection card; serialized under the "SelectImage" key.</summary>
    [JsonPropertyName("SelectImage")]
    public SelectImageCardViewModel SelectImageCardViewModel { get; }

    /// <summary>
    /// Enabled state of the expander wrapping the upscaler card.
    /// </summary>
    /// <remarks>
    /// NOTE(review): relies on <c>GetCard</c> without an index resolving to the first
    /// <see cref="StackExpanderViewModel"/> added in the constructor ("Upscale") - confirm.
    /// </remarks>
    public bool IsUpscaleEnabled
    {
        get => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled;
        set => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled = value;
    }

    /// <summary>
    /// Enabled state of the expander wrapping the sharpen card
    /// (looked up with index 1, i.e. the second expander added in the constructor).
    /// </summary>
    public bool IsSharpenEnabled
    {
        get => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled;
        set => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled = value;
    }

    /// <summary>
    /// Resolves the child card view models from <paramref name="vmFactory"/> and assembles
    /// the "Upscale" and "Sharpen" expanders into <see cref="StackCardViewModel"/>.
    /// </summary>
    /// <param name="notificationService">Used to surface precondition failures to the user.</param>
    /// <param name="inferenceClientManager">Forwarded to the base class.</param>
    /// <param name="settingsManager">Forwarded to the base class.</param>
    /// <param name="vmFactory">Factory used to create the card view models.</param>
    public InferenceImageUpscaleViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ISettingsManager settingsManager,
        ServiceManager<ViewModelBase> vmFactory
    )
        : base(vmFactory, inferenceClientManager, notificationService, settingsManager)
    {
        this.notificationService = notificationService;

        UpscalerCardViewModel = vmFactory.Get<UpscalerCardViewModel>();
        SharpenCardViewModel = vmFactory.Get<SharpenCardViewModel>();
        SelectImageCardViewModel = vmFactory.Get<SelectImageCardViewModel>();

        StackCardViewModel = vmFactory.Get<StackCardViewModel>();

        // Order matters: IsUpscaleEnabled / IsSharpenEnabled look the expanders up by position.
        StackCardViewModel.AddCards(
            vmFactory.Get<StackExpanderViewModel>(stackExpander =>
            {
                stackExpander.Title = "Upscale";
                stackExpander.AddCards(UpscalerCardViewModel);
            }),
            vmFactory.Get<StackExpanderViewModel>(stackExpander =>
            {
                stackExpander.Title = "Sharpen";
                stackExpander.AddCards(SharpenCardViewModel);
            })
        );
    }

    /// <inheritdoc />
    /// <remarks>Yields the selected image source if one is set; otherwise yields nothing.</remarks>
    protected override IEnumerable<ImageSource> GetInputImages()
    {
        if (SelectImageCardViewModel.ImageSource is { } imageSource)
        {
            yield return imageSource;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Applies the image-selection step, then chains the optional upscale group and the
    /// optional sharpen node onto the builder's primary connection, and finally sets up
    /// the output image. When upscaling is enabled an upscaler must be selected.
    /// </remarks>
    protected override void BuildPrompt(BuildPromptEventArgs args)
    {
        base.BuildPrompt(args);

        var builder = args.Builder;
        var nodes = builder.Nodes;

        // Setup image source
        SelectImageCardViewModel.ApplyStep(args);

        // If upscale is enabled, add another upscale group
        if (IsUpscaleEnabled)
        {
            var upscaleSize = builder.Connections.PrimarySize.WithScale(UpscalerCardViewModel.Scale);

            // Build group; the result becomes the new primary connection
            builder.Connections.Primary = builder
                .Group_UpscaleToImage(
                    "Upscale",
                    builder.GetPrimaryAsImage(),
                    UpscalerCardViewModel.SelectedUpscaler!.Value,
                    upscaleSize.Width,
                    upscaleSize.Height
                )
                .Output;
        }

        // If sharpen is enabled, add another sharpen group
        if (IsSharpenEnabled)
        {
            builder.Connections.Primary = nodes
                .AddTypedNode(
                    new ComfyNodeBuilder.ImageSharpen
                    {
                        Name = "Sharpen",
                        Image = builder.GetPrimaryAsImage(),
                        SharpenRadius = SharpenCardViewModel.SharpenRadius,
                        Sigma = SharpenCardViewModel.Sigma,
                        Alpha = SharpenCardViewModel.Alpha
                    }
                )
                .Output;
        }

        builder.SetupOutputImage();
    }

    /// <inheritdoc />
    /// <remarks>
    /// Shows a notification and returns without generating when the client is not connected,
    /// no local input image is selected, or upscaling is enabled without a selected upscaler.
    /// Otherwise uploads the input images, builds the prompt and runs the generation.
    /// </remarks>
    protected override async Task GenerateImageImpl(GenerateOverrides overrides, CancellationToken cancellationToken)
    {
        if (!ClientManager.IsConnected)
        {
            notificationService.Show("Client not connected", "Please connect first");
            return;
        }

        // Only the presence of a local file path matters here; the path itself is not used.
        if (SelectImageCardViewModel.ImageSource?.LocalFile?.FullPath is null)
        {
            notificationService.Show("No image selected", "Please select an image first");
            return;
        }

        // BuildPrompt dereferences SelectedUpscaler when the upscale step is enabled, so
        // reject the request up front instead of failing after the images were uploaded.
        if (IsUpscaleEnabled && UpscalerCardViewModel.SelectedUpscaler is null)
        {
            notificationService.Show("No upscaler selected", "Please select an upscaler first");
            return;
        }

        foreach (var image in GetInputImages())
        {
            await ClientManager.UploadInputImageAsync(image, cancellationToken);
        }

        var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides };
        BuildPrompt(buildPromptArgs);

        var generationArgs = new ImageGenerationEventArgs
        {
            Client = ClientManager.Client,
            Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
            OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
            Parameters = new GenerationParameters { ModelName = UpscalerCardViewModel.SelectedUpscaler?.Name, },
            Project = InferenceProjectDocument.FromLoadable(this)
        };

        await RunGeneration(generationArgs, cancellationToken);
    }
}