You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
203 lines
7.0 KiB
203 lines
7.0 KiB
using System; |
|
using System.Diagnostics.CodeAnalysis; |
|
using System.Drawing; |
|
using System.Linq; |
|
using System.Text.Json.Serialization; |
|
using System.Threading; |
|
using System.Threading.Tasks; |
|
using AsyncAwaitBestPractices; |
|
using Avalonia.Controls.Shapes; |
|
using Avalonia.Threading; |
|
using DynamicData.Binding; |
|
using NLog; |
|
using StabilityMatrix.Avalonia.Extensions; |
|
using StabilityMatrix.Avalonia.Models; |
|
using StabilityMatrix.Avalonia.Models.Inference; |
|
using StabilityMatrix.Avalonia.Services; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Avalonia.Views.Inference; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Models; |
|
using StabilityMatrix.Core.Models.Api.Comfy.Nodes; |
|
using Path = System.IO.Path; |
|
|
|
#pragma warning disable CS0657 // Not a valid attribute location for this declaration |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.Inference; |
|
|
|
[View(typeof(InferenceImageUpscaleView), persistent: true)] |
|
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] |
|
public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase |
|
{ |
|
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); |
|
|
|
private readonly INotificationService notificationService; |
|
|
|
[JsonIgnore] |
|
public StackCardViewModel StackCardViewModel { get; } |
|
|
|
[JsonPropertyName("Upscaler")] |
|
public UpscalerCardViewModel UpscalerCardViewModel { get; } |
|
|
|
[JsonPropertyName("Sharpen")] |
|
public SharpenCardViewModel SharpenCardViewModel { get; } |
|
|
|
[JsonPropertyName("SelectImage")] |
|
public SelectImageCardViewModel SelectImageCardViewModel { get; } |
|
|
|
public bool IsUpscaleEnabled |
|
{ |
|
get => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled; |
|
set => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled = value; |
|
} |
|
|
|
public bool IsSharpenEnabled |
|
{ |
|
get => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled; |
|
set => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled = value; |
|
} |
|
|
|
public InferenceImageUpscaleViewModel( |
|
INotificationService notificationService, |
|
IInferenceClientManager inferenceClientManager, |
|
ServiceManager<ViewModelBase> vmFactory |
|
) |
|
: base(vmFactory, inferenceClientManager, notificationService) |
|
{ |
|
this.notificationService = notificationService; |
|
|
|
UpscalerCardViewModel = vmFactory.Get<UpscalerCardViewModel>(); |
|
SharpenCardViewModel = vmFactory.Get<SharpenCardViewModel>(); |
|
SelectImageCardViewModel = vmFactory.Get<SelectImageCardViewModel>(); |
|
|
|
StackCardViewModel = vmFactory.Get<StackCardViewModel>(); |
|
StackCardViewModel.AddCards( |
|
new LoadableViewModelBase[] |
|
{ |
|
// Upscaler |
|
vmFactory.Get<StackExpanderViewModel>(stackExpander => |
|
{ |
|
stackExpander.Title = "Upscale"; |
|
stackExpander.AddCards(new LoadableViewModelBase[] { UpscalerCardViewModel }); |
|
}), |
|
// Sharpen |
|
vmFactory.Get<StackExpanderViewModel>(stackExpander => |
|
{ |
|
stackExpander.Title = "Sharpen"; |
|
stackExpander.AddCards(new LoadableViewModelBase[] { SharpenCardViewModel }); |
|
}) |
|
} |
|
); |
|
|
|
// On any new images, copy to input dir |
|
SelectImageCardViewModel |
|
.WhenPropertyChanged(x => x.ImageSource) |
|
.Subscribe(e => |
|
{ |
|
if (e.Value?.LocalFile?.FullPath is { } path) |
|
{ |
|
ClientManager.CopyImageToInputAsync(path).SafeFireAndForget(); |
|
} |
|
}); |
|
} |
|
|
|
/// <inheritdoc /> |
|
protected override void BuildPrompt(BuildPromptEventArgs args) |
|
{ |
|
base.BuildPrompt(args); |
|
|
|
var builder = args.Builder; |
|
var nodes = builder.Nodes; |
|
|
|
// Get source image |
|
var sourceImage = SelectImageCardViewModel.ImageSource; |
|
var sourceImageRelativePath = Path.Combine("Inference", sourceImage!.LocalFile!.Name); |
|
var sourceImageSize = |
|
SelectImageCardViewModel.CurrentBitmapSize |
|
?? throw new InvalidOperationException("Source image size is null"); |
|
|
|
// Set source size |
|
builder.Connections.ImageSize = sourceImageSize; |
|
|
|
// Load source |
|
var loadImage = nodes.AddNamedNode( |
|
ComfyNodeBuilder.LoadImage("LoadImage", sourceImageRelativePath) |
|
); |
|
builder.Connections.Image = loadImage.Output1; |
|
|
|
// If upscale is enabled, add another upscale group |
|
if (IsUpscaleEnabled) |
|
{ |
|
var upscaleSize = builder.Connections.GetScaledImageSize(UpscalerCardViewModel.Scale); |
|
|
|
// Build group |
|
var upscaleGroup = builder.Group_UpscaleToImage( |
|
"Upscale", |
|
builder.Connections.Image!, |
|
UpscalerCardViewModel.SelectedUpscaler!.Value, |
|
upscaleSize.Width, |
|
upscaleSize.Height |
|
); |
|
|
|
// Set as the image output |
|
builder.Connections.Image = upscaleGroup.Output; |
|
} |
|
|
|
// If sharpen is enabled, add another sharpen group |
|
if (IsSharpenEnabled) |
|
{ |
|
var sharpenGroup = nodes.AddNamedNode( |
|
ComfyNodeBuilder.ImageSharpen( |
|
"Sharpen", |
|
builder.Connections.Image, |
|
SharpenCardViewModel.SharpenRadius, |
|
SharpenCardViewModel.Sigma, |
|
SharpenCardViewModel.Alpha |
|
) |
|
); |
|
|
|
// Set as the image output |
|
builder.Connections.Image = sharpenGroup.Output; |
|
} |
|
|
|
builder.SetupOutputImage(); |
|
} |
|
|
|
/// <inheritdoc /> |
|
protected override async Task GenerateImageImpl( |
|
GenerateOverrides overrides, |
|
CancellationToken cancellationToken |
|
) |
|
{ |
|
if (!ClientManager.IsConnected) |
|
{ |
|
notificationService.Show("Client not connected", "Please connect first"); |
|
return; |
|
} |
|
|
|
if (SelectImageCardViewModel.ImageSource?.LocalFile?.FullPath is not { } path) |
|
{ |
|
notificationService.Show("No image selected", "Please select an image first"); |
|
return; |
|
} |
|
|
|
await ClientManager.CopyImageToInputAsync(path, cancellationToken); |
|
|
|
var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides }; |
|
BuildPrompt(buildPromptArgs); |
|
|
|
var generationArgs = new ImageGenerationEventArgs |
|
{ |
|
Client = ClientManager.Client, |
|
Nodes = buildPromptArgs.Builder.ToNodeDictionary(), |
|
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(), |
|
Parameters = new GenerationParameters |
|
{ |
|
ModelName = UpscalerCardViewModel.SelectedUpscaler?.Name, |
|
}, |
|
Project = InferenceProjectDocument.FromLoadable(this) |
|
}; |
|
|
|
await RunGeneration(generationArgs, cancellationToken); |
|
} |
|
}
|
|
|