Browse Source

Add upscale option and fix GAN upscales

pull/165/head
Ionite 1 year ago
parent
commit
14f7f240be
No known key found for this signature in database
  1. 8
      StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml
  2. 51
      StabilityMatrix.Avalonia/Controls/SamplerCard.axaml
  3. 6
      StabilityMatrix.Avalonia/DesignData/DesignData.cs
  4. 37
      StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs
  5. 368
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  6. 12
      StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs
  7. 7
      StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
  8. 23
      StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml
  9. 168
      StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

8
StabilityMatrix.Avalonia/Controls/ImageGalleryCard.axaml

@ -65,11 +65,19 @@
StretchDirection="Both">
<controls:BetterAdvancedImage.ContextFlyout>
<ui:FAMenuFlyout>
<ui:MenuFlyoutItem
Command="{ReflectionBinding #ImageCarousel.DataContext.FlyoutPreviewCommand}"
CommandParameter="{Binding #ImageView.CurrentImage}"
Text="Expanded Preview"
HotKey="Space"
IconSource="ZoomInFilled" />
<ui:MenuFlyoutSeparator/>
<ui:MenuFlyoutItem
IsEnabled="{OnPlatform Windows=True, Default=False}"
Command="{ReflectionBinding #ImageCarousel.DataContext.FlyoutCopyCommand}"
CommandParameter="{Binding #ImageView.CurrentImage}"
Text="Copy"
HotKey="Ctrl+C"
IconSource="Copy" />
</ui:FAMenuFlyout>
</controls:BetterAdvancedImage.ContextFlyout>

51
StabilityMatrix.Avalonia/Controls/SamplerCard.axaml

@ -6,7 +6,7 @@
x:DataType="vmInference:SamplerCardViewModel"
xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData">
<Design.PreviewWith>
<StackPanel MinWidth="300" Spacing="16">
<StackPanel MinWidth="350" Spacing="16">
<controls:SamplerCard DataContext="{x:Static mocks:DesignData.SamplerCardViewModel}"/>
<controls:SamplerCard DataContext="{x:Static mocks:DesignData.SamplerCardViewModelScaleMode}"/>
</StackPanel>
@ -21,7 +21,7 @@
<StackPanel
Margin="8"
HorizontalAlignment="{TemplateBinding HorizontalAlignment}"
Spacing="16">
Spacing="12">
<Grid ColumnDefinitions="Auto,*" RowDefinitions="*,*,*">
<!-- Sampler -->
<TextBlock
@ -77,34 +77,35 @@
<StackPanel>
<!-- Denoise Strength -->
<Grid IsVisible="{Binding IsDenoiseStrengthEnabled}">
<StackPanel>
<Grid ColumnDefinitions="*,Auto">
<TextBlock
VerticalAlignment="Center"
Text="Denoising Strength"/>
<ui:NumberBox
Grid.Column="1"
Margin="4,0,0,0"
ValidationMode="InvalidInputOverwritten"
SmallChange="0.01"
SimpleNumberFormat="F2"
Value="{Binding DenoiseStrength}"
HorizontalAlignment="Stretch"
SpinButtonPlacementMode="Compact"/>
</Grid>
<Slider
Minimum="0"
Maximum="1"
<StackPanel IsVisible="{Binding IsDenoiseStrengthEnabled}">
<Grid ColumnDefinitions="*,Auto">
<TextBlock
VerticalAlignment="Center"
Text="Denoising Strength"/>
<ui:NumberBox
Grid.Column="1"
Margin="4,0,0,0"
ValidationMode="InvalidInputOverwritten"
SmallChange="0.01"
SimpleNumberFormat="F2"
Value="{Binding DenoiseStrength}"
TickFrequency="1"
TickPlacement="BottomRight"/>
</StackPanel>
</Grid>
HorizontalAlignment="Stretch"
MinWidth="70"
SpinButtonPlacementMode="Compact"/>
</Grid>
<Slider
Minimum="0"
Maximum="1"
Value="{Binding DenoiseStrength}"
TickFrequency="1"
Margin="0,0,0,-4"
TickPlacement="BottomRight"/>
</StackPanel>
<!-- Dimensions (Absolute) -->
<Grid
IsVisible="{Binding IsDimensionsEnabled}"
Margin="0,4,0,0"
ColumnDefinitions="*,Auto,*"
RowDefinitions="Auto,*">
<TextBlock

6
StabilityMatrix.Avalonia/DesignData/DesignData.cs

@ -524,7 +524,11 @@ public static class DesignData
}
public static ImageViewerViewModel ImageViewerViewModel
=> DialogFactory.Get<ImageViewerViewModel>();
=> DialogFactory.Get<ImageViewerViewModel>(vm =>
{
vm.Source =
"https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/a318ac1f-3ad0-48ac-98cc-79126febcc17/width=1024";
});
public static Indexer Types => new();

37
StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs

@ -11,7 +11,10 @@ using NLog;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.Views.Dialogs;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
@ -21,7 +24,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
public partial class ImageGalleryCardViewModel : ViewModelBase
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ServiceManager<ViewModelBase> vmFactory;
[ObservableProperty]
private bool isPreviewOverlayEnabled;
@ -41,8 +45,10 @@ public partial class ImageGalleryCardViewModel : ViewModelBase
public bool CanNavigateBack => SelectedImageIndex > 0;
public bool CanNavigateForward => SelectedImageIndex < ImageSources.Count - 1;
public ImageGalleryCardViewModel()
public ImageGalleryCardViewModel(ServiceManager<ViewModelBase> vmFactory)
{
this.vmFactory = vmFactory;
ImageSources.CollectionChanged += OnImageSourcesItemsChanged;
}
@ -82,7 +88,7 @@ public partial class ImageGalleryCardViewModel : ViewModelBase
return;
}
Logger.Trace($"FlyoutCopy is copying {image}");
Logger.Trace($"FlyoutCopy is copying bitmap...");
await Task.Run(() =>
{
@ -92,4 +98,29 @@ public partial class ImageGalleryCardViewModel : ViewModelBase
}
});
}
[RelayCommand]
// ReSharper disable once UnusedMember.Local
private async Task FlyoutPreview(IImage? image)
{
if (image is null)
{
Logger.Trace("FlyoutPreview: image is null");
return;
}
Logger.Trace($"FlyoutPreview opening...");
var viewerVm = vmFactory.Get<ImageViewerViewModel>();
viewerVm.Image = (Bitmap) image;
var dialog = new BetterContentDialog
{
Content = new ImageViewerDialog
{
DataContext = viewerVm,
}
};
}
}

368
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -43,22 +43,20 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
private readonly ServiceManager<ViewModelBase> vmFactory;
public IInferenceClientManager ClientManager { get; }
public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; }
public PromptCardViewModel PromptCardViewModel { get; }
public StackCardViewModel StackCardViewModel { get; }
public UpscalerCardViewModel UpscalerCardViewModel =>
StackCardViewModel
.GetCard<StackExpanderViewModel>()
.GetCard<UpscalerCardViewModel>();
public UpscalerCardViewModel UpscalerCardViewModel =>
StackCardViewModel.GetCard<StackExpanderViewModel>().GetCard<UpscalerCardViewModel>();
public SamplerCardViewModel HiresSamplerCardViewModel =>
StackCardViewModel
.GetCard<StackExpanderViewModel>()
.GetCard<SamplerCardViewModel>();
StackCardViewModel.GetCard<StackExpanderViewModel>().GetCard<SamplerCardViewModel>();
public bool IsHiresFixEnabled => StackCardViewModel.GetCard<StackExpanderViewModel>().IsEnabled;
public bool IsUpscaleEnabled => StackCardViewModel.GetCard<StackExpanderViewModel>(1).IsEnabled;
[JsonIgnore]
public ProgressViewModel OutputProgress { get; } = new();
@ -78,44 +76,58 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
ClientManager = inferenceClientManager;
// Get sub view models from service manager
var seedCard = vmFactory.Get<SeedCardViewModel>();
var seedCard = vmFactory.Get<SeedCardViewModel>();
seedCard.GenerateNewSeed();
ImageGalleryCardViewModel = vmFactory.Get<ImageGalleryCardViewModel>();
PromptCardViewModel = vmFactory.Get<PromptCardViewModel>();
StackCardViewModel = vmFactory.Get<StackCardViewModel>();
StackCardViewModel.AddCards(new LoadableViewModelBase[]
{
// Model Card
vmFactory.Get<ModelCardViewModel>(),
// Sampler
vmFactory.Get<SamplerCardViewModel>(),
// Hires Fix
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
StackCardViewModel.AddCards(
new LoadableViewModelBase[]
{
stackExpander.Title = "Hires Fix";
stackExpander.AddCards(new LoadableViewModelBase[]
// Model Card
vmFactory.Get<ModelCardViewModel>(),
// Sampler
vmFactory.Get<SamplerCardViewModel>(),
// Hires Fix
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
stackExpander.Title = "Hires Fix";
stackExpander.AddCards(
new LoadableViewModelBase[]
{
// Hires Fix Upscaler
vmFactory.Get<UpscalerCardViewModel>(),
// Hires Fix Sampler
vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{
samplerCard.IsDimensionsEnabled = false;
samplerCard.IsCfgScaleEnabled = false;
samplerCard.IsSamplerSelectionEnabled = false;
samplerCard.IsDenoiseStrengthEnabled = true;
})
}
);
}),
vmFactory.Get<StackExpanderViewModel>(stackExpander =>
{
// Hires Fix Upscaler
vmFactory.Get<UpscalerCardViewModel>(),
// Hires Fix Sampler
vmFactory.Get<SamplerCardViewModel>(samplerCard =>
{
samplerCard.IsDimensionsEnabled = false;
samplerCard.IsCfgScaleEnabled = false;
samplerCard.IsSamplerSelectionEnabled = false;
samplerCard.IsDenoiseStrengthEnabled = true;
})
});
}),
// Seed
seedCard,
// Batch Size
vmFactory.Get<BatchSizeCardViewModel>(),
});
stackExpander.Title = "Upscale";
stackExpander.AddCards(
new LoadableViewModelBase[]
{
// Post processing upscaler
vmFactory.Get<UpscalerCardViewModel>(),
});
}),
// Seed
seedCard,
// Batch Size
vmFactory.Get<BatchSizeCardViewModel>(),
}
);
GenerateImageCommand.WithNotificationErrorHandler(notificationService);
}
@ -123,7 +135,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
private (NodeDictionary prompt, string[] outputs) BuildPrompt()
{
using var _ = new CodeTimer();
var samplerCard = StackCardViewModel.GetCard<SamplerCardViewModel>();
var batchCard = StackCardViewModel.GetCard<BatchSizeCardViewModel>();
var modelCard = StackCardViewModel.GetCard<ModelCardViewModel>();
@ -131,144 +143,179 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
var prompt = new NodeDictionary();
var builder = new ComfyNodeBuilder(prompt);
var checkpointLoader = prompt.AddNamedNode(new NamedComfyNode("CheckpointLoader")
{
ClassType = "CheckpointLoaderSimple",
Inputs = new Dictionary<string, object?>
var checkpointLoader = prompt.AddNamedNode(
new NamedComfyNode("CheckpointLoader")
{
["ckpt_name"] = modelCard.SelectedModelName
ClassType = "CheckpointLoaderSimple",
Inputs = new Dictionary<string, object?>
{
["ckpt_name"] = modelCard.SelectedModelName
}
}
});
);
var checkpointVae = checkpointLoader.GetOutput<VAENodeConnection>(2);
var emptyLatentImage = prompt.AddNamedNode(new NamedComfyNode("EmptyLatentImage")
{
ClassType = "EmptyLatentImage",
Inputs = new Dictionary<string, object?>
var emptyLatentImage = prompt.AddNamedNode(
new NamedComfyNode("EmptyLatentImage")
{
["batch_size"] = batchCard.BatchSize,
["height"] = samplerCard.Height,
["width"] = samplerCard.Width,
ClassType = "EmptyLatentImage",
Inputs = new Dictionary<string, object?>
{
["batch_size"] = batchCard.BatchSize,
["height"] = samplerCard.Height,
["width"] = samplerCard.Width,
}
}
});
var positiveClip = prompt.AddNamedNode(new NamedComfyNode("PositiveCLIP")
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?>
);
var positiveClip = prompt.AddNamedNode(
new NamedComfyNode("PositiveCLIP")
{
["clip"] = checkpointLoader.GetOutput(1),
["text"] = PromptCardViewModel.PromptDocument.Text,
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?>
{
["clip"] = checkpointLoader.GetOutput(1),
["text"] = PromptCardViewModel.PromptDocument.Text,
}
}
});
var negativeClip = prompt.AddNamedNode(new NamedComfyNode("NegativeCLIP")
{
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?>
);
var negativeClip = prompt.AddNamedNode(
new NamedComfyNode("NegativeCLIP")
{
["clip"] = checkpointLoader.GetOutput(1),
["text"] = PromptCardViewModel.NegativePromptDocument.Text,
ClassType = "CLIPTextEncode",
Inputs = new Dictionary<string, object?>
{
["clip"] = checkpointLoader.GetOutput(1),
["text"] = PromptCardViewModel.NegativePromptDocument.Text,
}
}
});
var sampler = prompt.AddNamedNode(ComfyNodeBuilder.KSampler(
"Sampler",
checkpointLoader.GetOutput<ModelNodeConnection>(0),
Convert.ToUInt64(seedCard.Seed),
samplerCard.Steps,
samplerCard.CfgScale,
samplerCard.SelectedSampler?.Name ?? throw new InvalidOperationException("Sampler not selected"),
"normal",
positiveClip.GetOutput<ConditioningNodeConnection>(0),
negativeClip.GetOutput<ConditioningNodeConnection>(0),
emptyLatentImage.GetOutput<LatentNodeConnection>(0),
samplerCard.DenoiseStrength));
);
var sampler = prompt.AddNamedNode(
ComfyNodeBuilder.KSampler(
"Sampler",
checkpointLoader.GetOutput<ModelNodeConnection>(0),
Convert.ToUInt64(seedCard.Seed),
samplerCard.Steps,
samplerCard.CfgScale,
samplerCard.SelectedSampler?.Name
?? throw new InvalidOperationException("Sampler not selected"),
"normal",
positiveClip.GetOutput<ConditioningNodeConnection>(0),
negativeClip.GetOutput<ConditioningNodeConnection>(0),
emptyLatentImage.GetOutput<LatentNodeConnection>(0),
samplerCard.DenoiseStrength
)
);
var lastLatent = sampler.Output;
var lastLatentWidth = samplerCard.Width;
var lastLatentHeight = samplerCard.Height;
var vaeDecoder = prompt.AddNamedNode(new NamedComfyNode("VAEDecoder")
{
ClassType = "VAEDecode",
Inputs = new Dictionary<string, object?>
var vaeDecoder = prompt.AddNamedNode(
new NamedComfyNode("VAEDecoder")
{
["samples"] = sampler.GetOutput(0),
["vae"] = checkpointLoader.GetOutput(2)
ClassType = "VAEDecode",
Inputs = new Dictionary<string, object?>
{
["samples"] = lastLatent,
["vae"] = checkpointLoader.GetOutput(2)
}
}
});
var saveImage = prompt.AddNamedNode(new NamedComfyNode("SaveImage")
{
ClassType = "SaveImage",
Inputs = new Dictionary<string, object?>
);
var saveImage = prompt.AddNamedNode(
new NamedComfyNode("SaveImage")
{
["filename_prefix"] = "SM-Inference",
["images"] = vaeDecoder.GetOutput(0)
ClassType = "SaveImage",
Inputs = new Dictionary<string, object?>
{
["filename_prefix"] = "SM-Inference",
["images"] = vaeDecoder.GetOutput(0)
}
}
});
);
// If hi-res fix is enabled, add the LatentUpscale node and another KSampler node
if (IsHiresFixEnabled)
{
var hiresUpscalerCard = UpscalerCardViewModel;
var hiresSamplerCard = HiresSamplerCardViewModel;
// Requested upscale to this size
var hiresWidth = (int)Math.Floor(lastLatentWidth * hiresUpscalerCard.Scale);
var hiresHeight = (int)Math.Floor(lastLatentHeight * hiresUpscalerCard.Scale);
LatentNodeConnection hiresLatent;
// Select between latent upscale and normal upscale based on the upscale method
var selectedUpscaler = hiresUpscalerCard.SelectedUpscaler;
LatentNodeConnection hiresOutput;
var selectedUpscaler = hiresUpscalerCard.SelectedUpscaler!.Value;
if (selectedUpscaler?.Type == ComfyUpscalerType.Latent)
{
hiresOutput = prompt.AddNamedNode(new NamedComfyNode("LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary<string, object?>
{
["upscale_method"] = hiresUpscalerCard.SelectedUpscaler?.Name,
["width"] = samplerCard.Width * hiresUpscalerCard.Scale,
["height"] = samplerCard.Height * hiresUpscalerCard.Scale,
["crop"] = "disabled",
["samples"] = sampler.Output
}
}).GetOutput<LatentNodeConnection>(0);
}
else if (selectedUpscaler?.Type == ComfyUpscalerType.ESRGAN)
if (selectedUpscaler.Type == ComfyUpscalerType.None)
{
// Convert to image space
var samplerImage = builder.Lambda_LatentToImage(sampler.Output, checkpointVae);
// Do group upscale
var modelUpscaler = builder.Group_UpscaleWithModel("Upscaler",
selectedUpscaler.Value.Name, samplerImage);
// Convert back to latent space
hiresOutput = builder.Lambda_ImageToLatent(modelUpscaler.Output, checkpointVae);
// If no upscaler selected or none, just reroute the latent image
hiresLatent = sampler.Output;
}
else
{
// If no upscaler selected or none, just reroute the latent image
hiresOutput = sampler.Output;
// Otherwise upscale the latent image
hiresLatent = builder.Group_UpscaleToLatent("HiresFix",
lastLatent, checkpointVae, selectedUpscaler, hiresWidth, hiresHeight).Output;
}
var hiresSampler = prompt.AddNamedNode(ComfyNodeBuilder.KSampler(
"HiresSampler",
checkpointLoader.GetOutput<ModelNodeConnection>(0),
Convert.ToUInt64(seedCard.Seed),
hiresSamplerCard.Steps,
hiresSamplerCard.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler name
hiresSamplerCard.SelectedSampler?.Name ?? samplerCard.SelectedSampler?.Name ?? throw new InvalidOperationException("Sampler not selected"),
"normal",
positiveClip.GetOutput<ConditioningNodeConnection>(0),
negativeClip.GetOutput<ConditioningNodeConnection>(0),
hiresOutput,
hiresSamplerCard.DenoiseStrength));
var hiresSampler = prompt.AddNamedNode(
ComfyNodeBuilder.KSampler(
"HiresSampler",
checkpointLoader.GetOutput<ModelNodeConnection>(0),
Convert.ToUInt64(seedCard.Seed),
hiresSamplerCard.Steps,
hiresSamplerCard.CfgScale,
// Use hires sampler name if not null, otherwise use the normal sampler name
hiresSamplerCard.SelectedSampler?.Name
?? samplerCard.SelectedSampler?.Name
?? throw new InvalidOperationException("Sampler not selected"),
"normal",
positiveClip.GetOutput<ConditioningNodeConnection>(0),
negativeClip.GetOutput<ConditioningNodeConnection>(0),
hiresLatent,
hiresSamplerCard.DenoiseStrength
)
);
// Set as last latent
lastLatent = hiresSampler.Output;
lastLatentWidth = hiresWidth;
lastLatentHeight = hiresHeight;
// Reroute the VAEDecoder's input to be from the hires sampler
vaeDecoder.Inputs["samples"] = hiresSampler.Output;
vaeDecoder.Inputs["samples"] = lastLatent;
}
// If upscale is enabled, add another upscale group
if (IsUpscaleEnabled)
{
var postUpscalerCard = StackCardViewModel.GetCard<StackExpanderViewModel>(1)
.GetCard<UpscalerCardViewModel>();
var upscaleWidth = (int)Math.Floor(lastLatentWidth * postUpscalerCard.Scale);
var upscaleHeight = (int)Math.Floor(lastLatentHeight * postUpscalerCard.Scale);
// Build group
var postUpscaleGroup = builder.Group_UpscaleToImage("PostUpscale",
lastLatent, checkpointVae, postUpscalerCard.SelectedUpscaler!.Value,
upscaleWidth, upscaleHeight);
// Remove the original vae decoder
prompt.Remove(vaeDecoder.Name);
// Set as the input for save image
saveImage.Inputs["images"] = postUpscaleGroup.Output;
}
prompt.NormalizeConnectionTypes();
return (prompt, new[] { saveImage.Name });
}
@ -278,8 +325,9 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
OutputProgress.Maximum = args.Maximum;
OutputProgress.IsIndeterminate = false;
OutputProgress.Text = $"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
OutputProgress.Text =
$"({args.Value} / {args.Maximum})"
+ (args.RunningNode != null ? $" {args.RunningNode}" : "");
}
private void OnPreviewImageReceived(object? sender, ComfyWebSocketImageData args)
@ -315,7 +363,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Connect progress handler
// client.ProgressUpdateReceived += OnProgressUpdateReceived;
client.PreviewImageReceived += OnPreviewImageReceived;
ComfyTask? promptTask = null;
try
{
@ -323,9 +371,11 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
cancellationToken.Register(() =>
{
Logger.Info("Cancelling prompt");
client.InterruptPromptAsync(new CancellationTokenSource(5000).Token).SafeFireAndForget();
client
.InterruptPromptAsync(new CancellationTokenSource(5000).Token)
.SafeFireAndForget();
});
try
{
promptTask = await client.QueuePromptAsync(nodes, cancellationToken);
@ -336,7 +386,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
await DialogHelper.CreateApiExceptionDialog(e, "Api Error").ShowAsync();
return;
}
// Register progress handler
promptTask.ProgressUpdate += OnProgressUpdateReceived;
@ -351,9 +401,10 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
);
ImageGalleryCardViewModel.ImageSources.Clear();
var images = imageOutputs[outputNodeNames[0]];
if (images is null) return;
if (images is null)
return;
List<ImageSource> outputImages;
// Use local file path if available, otherwise use remote URL
@ -369,26 +420,27 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
.Select(i => new ImageSource(i.ToUri(client.BaseAddress)))
.ToList();
}
// Download all images to make grid, if multiple
if (outputImages.Count > 1)
{
var loadedImages = outputImages.Select(i =>
SKImage.FromEncodedData(i.LocalFile?.Info.OpenRead())).ToImmutableArray();
var loadedImages = outputImages
.Select(i => SKImage.FromEncodedData(i.LocalFile?.Info.OpenRead()))
.ToImmutableArray();
var grid = ImageProcessor.CreateImageGrid(loadedImages);
// Save to disk
var lastName = outputImages.Last().LocalFile?.Info.Name;
var gridPath = client.OutputImagesDir!.JoinFile($"grid-{lastName}");
await using var fileStream = gridPath.Info.OpenWrite();
await fileStream.WriteAsync(grid.Encode().ToArray(), cancellationToken);
// Insert to start of images
ImageGalleryCardViewModel.ImageSources.Add(new ImageSource(gridPath));
}
// Add rest of images
ImageGalleryCardViewModel.ImageSources.AddRange(outputImages);
}
@ -400,7 +452,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
ImageGalleryCardViewModel.PreviewImage?.Dispose();
ImageGalleryCardViewModel.PreviewImage = null;
ImageGalleryCardViewModel.IsPreviewOverlayEnabled = false;
// client.ProgressUpdateReceived -= OnProgressUpdateReceived;
promptTask?.Dispose();
client.PreviewImageReceived -= OnPreviewImageReceived;

12
StabilityMatrix.Avalonia/ViewModels/Inference/StackExpanderViewModel.cs

@ -1,24 +1,29 @@
using System.Linq;
using System.Text.Json.Nodes;
using CommunityToolkit.Mvvm.ComponentModel;
using Newtonsoft.Json;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
#pragma warning disable CS0657 // Not a valid attribute location for this declaration
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
[View(typeof(StackExpander))]
public partial class StackExpanderViewModel : StackViewModelBase
{
[ObservableProperty] private string? title;
[ObservableProperty] private bool isEnabled;
[ObservableProperty]
[property: JsonIgnore]
private string? title;
[ObservableProperty]
private bool isEnabled;
/// <inheritdoc />
public override void LoadStateFromJsonObject(JsonObject state)
{
var model = DeserializeModel<StackExpanderModel>(state);
Title = model.Title;
IsEnabled = model.IsEnabled;
if (model.Cards is null) return;
@ -37,7 +42,6 @@ public partial class StackExpanderViewModel : StackViewModelBase
{
return SerializeModel(new StackExpanderModel
{
Title = Title,
IsEnabled = IsEnabled,
Cards = Cards.Select(x => x.SaveStateToJsonObject()).ToList()
});

7
StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs

@ -650,8 +650,8 @@ public partial class SettingsViewModel : PageViewModelBase
var imageBox = new ImageViewerDialog()
{
MinWidth = 1500,
MinHeight = 900,
Width = 1000,
Height = 1000,
DataContext = new ImageViewerViewModel()
{
Image = bitmap
@ -661,9 +661,10 @@ public partial class SettingsViewModel : PageViewModelBase
var dialog = new BetterContentDialog
{
MaxDialogWidth = 1000,
MaxDialogHeight = 1000,
FullSizeDesired = true,
Content = imageBox,
CloseButtonText = "Close",
IsFooterVisible = false,
ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled,
};

23
StabilityMatrix.Avalonia/Views/Dialogs/ImageViewerDialog.axaml

@ -8,23 +8,38 @@
xmlns:controls="clr-namespace:StabilityMatrix.Avalonia.Controls"
xmlns:system="clr-namespace:System;assembly=System.Runtime"
xmlns:vmDialogs="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Dialogs"
xmlns:icons="clr-namespace:Projektanker.Icons.Avalonia;assembly=Projektanker.Icons.Avalonia"
VerticalContentAlignment="Stretch"
HorizontalContentAlignment="Stretch"
d:DataContext="{x:Static mocks:DesignData.ImageViewerViewModel}"
x:DataType="vmDialogs:ImageViewerViewModel"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="StabilityMatrix.Avalonia.Views.Dialogs.ImageViewerDialog">
<Grid>
<Grid VerticalAlignment="Stretch" HorizontalAlignment="Stretch">
<controls:AdvancedImageBox
Name="MainImageBox"
RenderOptions.BitmapInterpolationMode="None"
Image="{Binding Image}"/>
Image="{Binding Image}"
Source="{Binding Source}"/>
<!-- The preview tracker -->
<Image
<!--<Image
MinHeight="100"
MinWidth="100"
RenderOptions.BitmapInterpolationMode="HighQuality"
Source="{Binding #MainImageBox.TrackerImage}"
HorizontalAlignment="Left"
VerticalAlignment="Bottom"/>
VerticalAlignment="Bottom"/> -->
<!-- Close button -->
<Grid
VerticalAlignment="Top"
HorizontalAlignment="Right">
<Button
Margin="0,8,8,0"
Classes="transparent"
Command="{Binding OnCloseButtonClick}"
icons:Attached.Icon="fa-solid fa-xmark"/>
</Grid>
</Grid>
</controls:UserControlBase>

168
StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

@ -40,7 +40,7 @@ public class ComfyNodeBuilder
return new NamedComfyNode<ImageNodeConnection>(name)
{
ClassType = "VAEDecode",
Inputs = new Dictionary<string, object?> { ["latent"] = samples.Data, ["vae"] = vae.Data }
Inputs = new Dictionary<string, object?> { ["samples"] = samples.Data, ["vae"] = vae.Data }
};
}
@ -105,6 +105,28 @@ public class ComfyNodeBuilder
};
}
public static NamedComfyNode<ImageNodeConnection> ImageScale(
string name,
ImageNodeConnection image,
string method,
int height,
int width,
bool crop)
{
return new NamedComfyNode<ImageNodeConnection>(name)
{
ClassType = "ImageScale",
Inputs = new Dictionary<string, object?>
{
["image"] = image.Data,
["upscale_method"] = method,
["height"] = height,
["width"] = width,
["crop"] = crop ? "center" : "disabled"
}
};
}
public ImageNodeConnection Lambda_LatentToImage(LatentNodeConnection latent, VAENodeConnection vae)
{
return nodes.AddNamedNode(VAEDecode($"{GetRandomPrefix()}_VAEDecode", latent, vae)).Output;
@ -116,7 +138,7 @@ public class ComfyNodeBuilder
}
/// <summary>
/// Create a upscaling node based on a <see cref="ComfyUpscalerType"/>
/// Create a group node that upscales a given image with a given model
/// </summary>
public NamedComfyNode<ImageNodeConnection> Group_UpscaleWithModel(string name, string modelName, ImageNodeConnection image)
{
@ -128,4 +150,146 @@ public class ComfyNodeBuilder
return upscaler;
}
/// <summary>
/// Create a group node that scales a given image to a given size
/// </summary>
public NamedComfyNode<LatentNodeConnection> Group_UpscaleToLatent(string name,
LatentNodeConnection latent, VAENodeConnection vae,
ComfyUpscaler upscaleInfo, int width, int height)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
return nodes
.AddNamedNode(
new NamedComfyNode<LatentNodeConnection>($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary<string, object?>
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space
var samplerImage = nodes.AddNamedNode(
VAEDecode(
$"{name}_VAEDecode",
latent,
vae
)
);
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage.Output
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
// Convert back to latent space
return nodes
.AddNamedNode(
VAEEncode(
$"{name}_VAEEncode",
resizedScaled.Output,
vae
)
);
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
/// <summary>
/// Create a group node that scales a given image to image output
/// </summary>
public NamedComfyNode<ImageNodeConnection> Group_UpscaleToImage(string name,
LatentNodeConnection latent, VAENodeConnection vae,
ComfyUpscaler upscaleInfo, int width, int height)
{
if (upscaleInfo.Type == ComfyUpscalerType.Latent)
{
var latentUpscale = nodes
.AddNamedNode(
new NamedComfyNode<LatentNodeConnection>($"{name}_LatentUpscale")
{
ClassType = "LatentUpscale",
Inputs = new Dictionary<string, object?>
{
["upscale_method"] = upscaleInfo.Name,
["width"] = width,
["height"] = height,
["crop"] = "disabled",
["samples"] = latent.Data,
}
}
);
// Convert to image space
return nodes.AddNamedNode(
VAEDecode(
$"{name}_VAEDecode",
latentUpscale.Output,
vae
)
);
}
if (upscaleInfo.Type == ComfyUpscalerType.ESRGAN)
{
// Convert to image space
var samplerImage = nodes.AddNamedNode(
VAEDecode(
$"{name}_VAEDecode",
latent,
vae
)
);
// Do group upscale
var modelUpscaler = Group_UpscaleWithModel(
$"{name}_ModelUpscale",
upscaleInfo.Name,
samplerImage.Output
);
// Since the model upscale is fixed to model (2x/4x), scale it again to the requested size
var resizedScaled = nodes.AddNamedNode(
ImageScale(
$"{name}_ImageScale",
modelUpscaler.Output,
"bilinear",
height,
width,
false
)
);
// No need to convert back to latent space
return resizedScaled;
}
throw new InvalidOperationException($"Unknown upscaler type: {upscaleInfo.Type}");
}
}

Loading…
Cancel
Save