using System;
using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls.Shapes;
using Avalonia.Threading;
using DynamicData.Binding;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.Views.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using Path = System.IO.Path;

#pragma warning disable CS0657 // Not a valid attribute location for this declaration

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

/// <summary>
/// View model backing <see cref="InferenceImageUpscaleView"/>. Takes a user-selected source
/// image, optionally runs it through an upscale group and/or a sharpen node, and submits the
/// resulting node graph for generation.
/// </summary>
[View(typeof(InferenceImageUpscaleView), persistent: true)]
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)]
public class InferenceImageUpscaleViewModel : InferenceGenerationViewModelBase
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;

    /// <summary>
    /// Container for the expander cards added in the constructor ("Upscale" first, then
    /// "Sharpen"). Excluded from JSON serialization.
    /// </summary>
    [JsonIgnore]
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Upscaler settings card; serialized under the JSON name "Upscaler".</summary>
    [JsonPropertyName("Upscaler")]
    public UpscalerCardViewModel UpscalerCardViewModel { get; }

    /// <summary>Sharpen settings card; serialized under the JSON name "Sharpen".</summary>
    [JsonPropertyName("Sharpen")]
    public SharpenCardViewModel SharpenCardViewModel { get; }

    /// <summary>Source image selection card; serialized under the JSON name "SelectImage".</summary>
    [JsonPropertyName("SelectImage")]
    public SelectImageCardViewModel SelectImageCardViewModel { get; }

    /// <summary>
    /// Gets or sets <c>IsEnabled</c> on the card returned by <c>StackCardViewModel.GetCard()</c>.
    /// </summary>
    /// <remarks>
    /// Presumably this is the "Upscale" expander, i.e. the first card added in the constructor;
    /// confirm against the default index used by <c>GetCard</c>.
    /// </remarks>
    public bool IsUpscaleEnabled
    {
        get => StackCardViewModel.GetCard().IsEnabled;
        set => StackCardViewModel.GetCard().IsEnabled = value;
    }

    /// <summary>
    /// Gets or sets <c>IsEnabled</c> on the card returned by <c>StackCardViewModel.GetCard(1)</c>.
    /// </summary>
    /// <remarks>
    /// Presumably this is the "Sharpen" expander, i.e. the second card added in the constructor;
    /// confirm that the index passed to <c>GetCard</c> is zero-based.
    /// </remarks>
    public bool IsSharpenEnabled
    {
        get => StackCardViewModel.GetCard(1).IsEnabled;
        set => StackCardViewModel.GetCard(1).IsEnabled = value;
    }

    // NOTE(review): the Get()/GetCard() calls in this class carry no explicit type arguments
    // even though their results are assigned to differently typed members. Generic arguments
    // may have been lost from this copy of the file; confirm against the upstream source.

    /// <summary>
    /// Creates the child card view models, arranges the upscaler and sharpen cards into
    /// titled expanders, and starts copying newly selected images to the client's input.
    /// </summary>
    /// <param name="notificationService">Used to show user-facing notifications.</param>
    /// <param name="inferenceClientManager">Forwarded to the base view model.</param>
    /// <param name="vmFactory">Factory used to obtain the child view models.</param>
    public InferenceImageUpscaleViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ServiceManager vmFactory
    )
        : base(vmFactory, inferenceClientManager, notificationService)
    {
        this.notificationService = notificationService;

        UpscalerCardViewModel = vmFactory.Get();
        SharpenCardViewModel = vmFactory.Get();
        SelectImageCardViewModel = vmFactory.Get();

        StackCardViewModel = vmFactory.Get();
        StackCardViewModel.AddCards(
            new LoadableViewModelBase[]
            {
                // Upscaler
                vmFactory.Get(stackExpander =>
                {
                    stackExpander.Title = "Upscale";
                    stackExpander.AddCards(new LoadableViewModelBase[] { UpscalerCardViewModel });
                }),
                // Sharpen
                vmFactory.Get(stackExpander =>
                {
                    stackExpander.Title = "Sharpen";
                    stackExpander.AddCards(new LoadableViewModelBase[] { SharpenCardViewModel });
                })
            }
        );

        // On any new images, copy to input dir
        SelectImageCardViewModel
            .WhenPropertyChanged(x => x.ImageSource)
            .Subscribe(e =>
            {
                if (e.Value?.LocalFile?.FullPath is { } path)
                {
                    // Best-effort early copy: GenerateImageImpl awaits the same copy again
                    // before generating, so a failure here is only logged.
                    ClientManager
                        .CopyImageToInputAsync(path)
                        .SafeFireAndForget(
                            onException: ex =>
                                Logger.Warn(ex, "Failed to copy selected image to input directory")
                        );
                }
            });
    }

    /// <inheritdoc />
    /// <remarks>
    /// Adds a LoadImage node for the selected source image, then, depending on
    /// <see cref="IsUpscaleEnabled"/> and <see cref="IsSharpenEnabled"/>, chains an upscale
    /// group and/or a sharpen node onto the current image connection before setting up the
    /// output image.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// No source image with a local file is selected, the source image size is unknown, or
    /// upscaling is enabled but no upscaler is selected.
    /// </exception>
    protected override void BuildPrompt(BuildPromptEventArgs args)
    {
        base.BuildPrompt(args);

        var builder = args.Builder;
        var nodes = builder.Nodes;

        // Get source image; fail with a descriptive error instead of a NullReferenceException
        var sourceImageFile =
            SelectImageCardViewModel.ImageSource?.LocalFile
            ?? throw new InvalidOperationException("Source image is null");
        var sourceImageRelativePath = Path.Combine("Inference", sourceImageFile.Name);
        var sourceImageSize =
            SelectImageCardViewModel.CurrentBitmapSize
            ?? throw new InvalidOperationException("Source image size is null");

        // Set source size
        builder.Connections.ImageSize = sourceImageSize;

        // Load source
        var loadImage = nodes.AddNamedNode(
            ComfyNodeBuilder.LoadImage("LoadImage", sourceImageRelativePath)
        );
        builder.Connections.Image = loadImage.Output1;

        // If upscale is enabled, add another upscale group
        if (IsUpscaleEnabled)
        {
            var selectedUpscaler = UpscalerCardViewModel.SelectedUpscaler;
            if (selectedUpscaler is null)
            {
                throw new InvalidOperationException("Selected upscaler is null");
            }

            var upscaleSize = builder.Connections.GetScaledImageSize(UpscalerCardViewModel.Scale);

            // Build group
            var upscaleGroup = builder.Group_UpscaleToImage(
                "Upscale",
                builder.Connections.Image!,
                selectedUpscaler!.Value,
                upscaleSize.Width,
                upscaleSize.Height
            );

            // Set as the image output
            builder.Connections.Image = upscaleGroup.Output;
        }

        // If sharpen is enabled, add another sharpen group
        if (IsSharpenEnabled)
        {
            var sharpenGroup = nodes.AddNamedNode(
                ComfyNodeBuilder.ImageSharpen(
                    "Sharpen",
                    builder.Connections.Image,
                    SharpenCardViewModel.SharpenRadius,
                    SharpenCardViewModel.Sigma,
                    SharpenCardViewModel.Alpha
                )
            );

            // Set as the image output
            builder.Connections.Image = sharpenGroup.Output;
        }

        builder.SetupOutputImage();
    }

    /// <inheritdoc />
    /// <remarks>
    /// Shows a notification and returns without generating when the client is not connected
    /// or when no image with a local file path is selected. Otherwise copies the image to the
    /// client's input, builds the prompt, and hands the result to <c>RunGeneration</c>.
    /// </remarks>
    protected override async Task GenerateImageImpl(
        GenerateOverrides overrides,
        CancellationToken cancellationToken
    )
    {
        if (!ClientManager.IsConnected)
        {
            notificationService.Show("Client not connected", "Please connect first");
            return;
        }

        if (SelectImageCardViewModel.ImageSource?.LocalFile?.FullPath is not { } path)
        {
            notificationService.Show("No image selected", "Please select an image first");
            return;
        }

        // Make sure the source image is present in the input before the graph references it
        await ClientManager.CopyImageToInputAsync(path, cancellationToken);

        var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides };
        BuildPrompt(buildPromptArgs);

        var generationArgs = new ImageGenerationEventArgs
        {
            Client = ClientManager.Client,
            Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
            OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
            Parameters = new GenerationParameters
            {
                ModelName = UpscalerCardViewModel.SelectedUpscaler?.Name,
            },
            Project = InferenceProjectDocument.FromLoadable(this)
        };

        await RunGeneration(generationArgs, cancellationToken);
    }
}