using System;
using System.Linq;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.Input;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models.Api.Comfy;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
/// <summary>
/// Inference module that adds a "hires fix" pass: an optional upscale of the current
/// primary output followed by a second KSampler pass over it.
/// </summary>
[ManagedService]
[Transient]
public partial class HiresFixModule : ModuleBase
{
    /// <inheritdoc />
    public override bool IsSettingsEnabled => true;

    /// <inheritdoc />
    /// <remarks>
    /// Backed by <c>OpenSettingsDialogCommand</c>, which the CommunityToolkit
    /// <c>[RelayCommand]</c> generator emits for <see cref="OpenSettingsDialog"/>.
    /// </remarks>
    public override IRelayCommand SettingsCommand => OpenSettingsDialogCommand;

    /// <summary>
    /// Creates the module titled "HiresFix" and registers its two cards.
    /// The second card has denoise strength enabled, because <see cref="OnApplyStep"/>
    /// passes its <c>DenoiseStrength</c> to the hires KSampler.
    /// </summary>
    /// <param name="vmFactory">Factory used to create the card view models.</param>
    public HiresFixModule(ServiceManager vmFactory)
        : base(vmFactory)
    {
        Title = "HiresFix";
        AddCards(
            vmFactory.Get(),
            vmFactory.Get(vmSampler =>
            {
                vmSampler.IsDenoiseStrengthEnabled = true;
            })
        );
    }

    /// <summary>
    /// Shows a dialog whose selected object is every card of this module,
    /// restricted to the "Settings" category.
    /// </summary>
    [RelayCommand]
    private async Task OpenSettingsDialog()
    {
        var gridVm = VmFactory.Get(vm =>
        {
            vm.Title = $"{Title} {Resources.Label_Settings}";
            vm.SelectedObject = Cards.ToArray();
            vm.IncludeCategories = ["Settings"];
        });

        await gridVm.GetDialog().ShowAsync();
    }

    /// <inheritdoc />
    /// <remarks>
    /// When an upscaler other than <see cref="ComfyUpscalerType.None"/> is selected, the
    /// primary output is first upscaled to <c>PrimarySize</c> scaled by the upscale card's
    /// <c>Scale</c>. A KSampler node then samples the primary latent using the refiner-or-base
    /// model and conditioning, the shared seed, and the sampler card's steps, CFG and denoise.
    /// The sampler and scheduler fall back to the primary ones when the card has none selected.
    /// The sampler output becomes the new primary.
    /// </remarks>
    /// <exception cref="ArgumentException">
    /// Thrown when neither the sampler card nor the builder provides a sampler or scheduler.
    /// </exception>
    protected override void OnApplyStep(ModuleApplyStepEventArgs e)
    {
        var builder = e.Builder;

        var upscaleCard = GetCard();
        var samplerCard = GetCard();

        // Target size of the hires pass, scaled from the current primary size
        var hiresSize = builder.Connections.PrimarySize.WithScale(upscaleCard.Scale);

        // Assumes an upscaler entry is always selected (dereferenced with `!`, no null check)
        var selectedUpscaler = upscaleCard.SelectedUpscaler!.Value;

        // With the "None" upscaler no resize node is added, so the primary keeps its size
        var isUpscaling = selectedUpscaler.Type != ComfyUpscalerType.None;

        if (isUpscaling)
        {
            builder.Connections.Primary = builder.Group_Upscale(
                builder.Nodes.GetUniqueName("HiresFix"),
                builder.Connections.Primary.Unwrap(),
                builder.Connections.GetDefaultVAE(),
                selectedUpscaler,
                hiresSize.Width,
                hiresSize.Height
            );
        }

        // Resolve conditioning once instead of once per Positive/Negative field
        var conditioning = builder.Connections.GetRefinerOrBaseConditioning();

        var hiresSampler = builder.Nodes.AddTypedNode(
            new ComfyNodeBuilder.KSampler
            {
                Name = builder.Nodes.GetUniqueName("HiresFix_Sampler"),
                Model = builder.Connections.GetRefinerOrBaseModel(),
                Seed = builder.Connections.Seed,
                Steps = samplerCard.Steps,
                Cfg = samplerCard.CfgScale,
                // Prefer the card's own choice, otherwise reuse the primary sampler/scheduler
                SamplerName =
                    samplerCard.SelectedSampler?.Name
                    ?? builder.Connections.PrimarySampler?.Name
                    ?? throw new ArgumentException("No PrimarySampler"),
                Scheduler =
                    samplerCard.SelectedScheduler?.Name
                    ?? builder.Connections.PrimaryScheduler?.Name
                    ?? throw new ArgumentException("No PrimaryScheduler"),
                Positive = conditioning.Positive,
                Negative = conditioning.Negative,
                LatentImage = builder.GetPrimaryAsLatent(),
                Denoise = samplerCard.DenoiseStrength
            }
        );

        // The sampler output becomes the new primary
        builder.Connections.Primary = hiresSampler.Output;

        // Only report the scaled size when a resize actually happened; otherwise the
        // sampled latent still has the previous primary size.
        if (isUpscaling)
        {
            builder.Connections.PrimarySize = hiresSize;
        }
    }
}