Browse Source

Add model based completion autocompletes

pull/165/head
Ionite 1 year ago
parent
commit
0984944c75
No known key found for this signature in database
  1. 22
      StabilityMatrix.Avalonia/Controls/ModelCard.axaml
  2. 29
      StabilityMatrix.Avalonia/Controls/PromptCard.axaml
  3. 2
      StabilityMatrix.Avalonia/DesignData/MockCompletionProvider.cs
  4. 42
      StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs
  5. 3
      StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs
  6. 206
      StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs
  7. 2
      StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs
  8. 11
      StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs
  9. 154
      StabilityMatrix.Avalonia/Services/InferenceClientManager.cs
  10. 26
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

22
StabilityMatrix.Avalonia/Controls/ModelCard.axaml

@ -7,9 +7,11 @@
xmlns:models="clr-namespace:StabilityMatrix.Core.Models;assembly=StabilityMatrix.Core"
x:DataType="inference:ModelCardViewModel">
<Design.PreviewWith>
<StackPanel MinWidth="300">
<controls:ModelCard DataContext="{x:Static mocks:DesignData.ModelCardViewModel}"/>
</StackPanel>
<Panel Width="400" Height="200">
<StackPanel VerticalAlignment="Center" Width="300">
<controls:ModelCard DataContext="{x:Static mocks:DesignData.ModelCardViewModel}"/>
</StackPanel>
</Panel>
</Design.PreviewWith>
<Style Selector="controls|ModelCard">
@ -21,13 +23,14 @@
<!-- Model -->
<TextBlock
Grid.Column="0"
TextAlignment="Left"
VerticalAlignment="Center"
HorizontalAlignment="Left"
MinWidth="60"
Text="Model" />
<ui:FAComboBox
Grid.Row="0"
Grid.Column="1"
IsTextSearchEnabled="True"
HorizontalAlignment="Stretch"
ItemsSource="{Binding ClientManager.Models}"
DisplayMemberBinding="{Binding ShortDisplayName}"
@ -69,6 +72,13 @@
Grid.Column="2"
Margin="8,0,0,0">
<ui:SymbolIcon FontSize="16" Symbol="Setting" />
<Button.Flyout>
<ui:FAMenuFlyout Placement="BottomEdgeAlignedLeft">
<ui:ToggleMenuFlyoutItem
IsChecked="{Binding IsVaeSelectionEnabled}"
Text="VAE"/>
</ui:FAMenuFlyout>
</Button.Flyout>
</Button>
<!-- VAE -->
@ -76,7 +86,9 @@
Grid.Column="0"
Grid.Row="1"
MinWidth="60"
IsVisible="{Binding IsVaeSelectionEnabled}"
Margin="0,8,0,0"
TextAlignment="Left"
VerticalAlignment="Center"
Text="VAE" />
@ -85,6 +97,8 @@
Grid.Column="1"
Grid.ColumnSpan="2"
Margin="0,8,0,0"
IsTextSearchEnabled="True"
IsVisible="{Binding IsVaeSelectionEnabled}"
HorizontalAlignment="Stretch"
ItemsSource="{Binding ClientManager.VaeModels}"
DisplayMemberBinding="{Binding ShortDisplayName}"

29
StabilityMatrix.Avalonia/Controls/PromptCard.axaml

@ -7,6 +7,7 @@
xmlns:i="clr-namespace:Avalonia.Xaml.Interactivity;assembly=Avalonia.Xaml.Interactivity"
xmlns:behaviors="clr-namespace:StabilityMatrix.Avalonia.Behaviors"
xmlns:icons="clr-namespace:Projektanker.Icons.Avalonia;assembly=Projektanker.Icons.Avalonia"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
x:DataType="vmInference:PromptCardViewModel">
<Design.PreviewWith>
<Grid Height="600" Width="600">
@ -55,6 +56,29 @@
<icons:Icon Value="fa-solid fa-caret-up" Margin="8,0" FontSize="10" />
</StackPanel>
<StackPanel
Grid.Row="0"
Grid.Column="1"
Margin="0,0,0,4"
Orientation="Horizontal"
HorizontalAlignment="Right">
<Button
Classes="transparent-full"
Padding="10,4"
Margin="0,-2,0,0"
VerticalAlignment="Top"
VerticalContentAlignment="Top"
icons:Attached.Icon="fa-solid fa-question"/>
<Button
Content="Show Tokens"
Padding="8,4"
IsVisible="{Binding SharedState.IsDebugMode}"
Command="{Binding DebugShowTokensCommand}"/>
</StackPanel>
<ExperimentalAcrylicBorder
Grid.Row="1"
Grid.Column="0"
@ -72,6 +96,8 @@
IsEnabled="{Binding IsAutoCompletionEnabled}"
CompletionProvider="{Binding CompletionProvider}"
TokenizerProvider="{Binding TokenizerProvider}"/>
<behaviors:TextEditorToolTipBehavior
TokenizerProvider="{Binding TokenizerProvider}"/>
</i:Interaction.Behaviors>
</avaloniaEdit:TextEditor>
@ -108,9 +134,12 @@
IsEnabled="{Binding IsAutoCompletionEnabled}"
CompletionProvider="{Binding CompletionProvider}"
TokenizerProvider="{Binding TokenizerProvider}"/>
<behaviors:TextEditorToolTipBehavior
TokenizerProvider="{Binding TokenizerProvider}"/>
</i:Interaction.Behaviors>
</avaloniaEdit:TextEditor>
</ExperimentalAcrylicBorder>
</Grid>
</Grid>

2
StabilityMatrix.Avalonia/DesignData/MockCompletionProvider.cs

@ -27,7 +27,7 @@ public class MockCompletionProvider : ICompletionProvider
}
/// <inheritdoc />
public IEnumerable<ICompletionData> GetCompletions(string searchTerm, int itemsCount, bool suggest)
public IEnumerable<ICompletionData> GetCompletions(TextCompletionRequest completionRequest, int itemsCount, bool suggest)
{
return Array.Empty<ICompletionData>();
}

42
StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using DynamicData.Binding;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
@ -13,26 +14,31 @@ public class MockInferenceClientManager : ObservableObject, IInferenceClientMana
{
public ComfyClient? Client { get; set; }
public IReadOnlyCollection<HybridModelFile>? Models { get; set; }
public IReadOnlyCollection<HybridModelFile>? VaeModels { get; set; }
public IObservableCollection<HybridModelFile> Models { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IReadOnlyCollection<ComfySampler>? Samplers { get; set; } = new ComfySampler[]
{
new("euler_ancestral"),
new("euler"),
new("lms"),
new("heun"),
new("dpm_2"),
new("dpm_2_ancestral")
};
public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>(new ComfySampler[]
{
new("euler_ancestral"),
new("euler"),
new("lms"),
new("heun"),
new("dpm_2"),
new("dpm_2_ancestral")
});
public IObservableCollection<ComfyUpscaler> Upscalers { get; } =
new ObservableCollectionExtended<ComfyUpscaler>(new ComfyUpscaler[]
{
new("nearest-exact", ComfyUpscalerType.Latent),
new("bicubic", ComfyUpscalerType.Latent),
new("ESRGAN-4x", ComfyUpscalerType.ESRGAN)
});
public IReadOnlyCollection<ComfyUpscaler>? Upscalers { get; set; } = new ComfyUpscaler[]
{
new("nearest-exact", ComfyUpscalerType.Latent),
new("bicubic", ComfyUpscalerType.Latent),
new("ESRGAN-4x", ComfyUpscalerType.ESRGAN)
};
public bool IsConnected { get; set; }
public Task ConnectAsync()

3
StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs

@ -8,6 +8,9 @@ namespace StabilityMatrix.Avalonia.DesignData;
public class MockModelIndexService : IModelIndexService
{
/// <inheritdoc />
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; } = new();
/// <inheritdoc />
public Task RefreshIndex()
{

206
StabilityMatrix.Avalonia/Models/TagCompletion/CompletionProvider.cs

@ -14,7 +14,9 @@ using Nito.AsyncEx;
using NLog;
using StabilityMatrix.Avalonia.Controls.CodeCompletion;
using StabilityMatrix.Avalonia.Helpers;
using StabilityMatrix.Avalonia.Models.Inference.Tokens;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
@ -29,26 +31,33 @@ public class CompletionProvider : ICompletionProvider
private readonly ISettingsManager settingsManager;
private readonly INotificationService notificationService;
private readonly IModelIndexService modelIndexService;
private readonly AsyncLock loadLock = new();
private readonly Dictionary<string, TagCsvEntry> entries = new();
private InMemoryIndexSearcher? searcher;
public bool IsLoaded => searcher is not null;
public Func<string, string>? PrepareInsertionText
=> settingsManager.Settings.IsCompletionRemoveUnderscoresEnabled
? PrepareInsertionText_RemoveUnderscores : null;
public CompletionProvider(ISettingsManager settingsManager, INotificationService notificationService)
public Func<string, string>? PrepareInsertionText =>
settingsManager.Settings.IsCompletionRemoveUnderscoresEnabled
? PrepareInsertionText_RemoveUnderscores
: null;
public CompletionProvider(
ISettingsManager settingsManager,
INotificationService notificationService,
IModelIndexService modelIndexService
)
{
this.settingsManager = settingsManager;
this.notificationService = notificationService;
this.modelIndexService = modelIndexService;
// Attach to load from set file on initial settings load
settingsManager.Loaded += (_, _) => UpdateTagCompletionCsv();
// Also load when TagCompletionCsv property changes
settingsManager.SettingsPropertyChanged += (_, args) =>
{
@ -57,25 +66,26 @@ public class CompletionProvider : ICompletionProvider
UpdateTagCompletionCsv();
}
};
// If library already loaded, start a background load
if (settingsManager.IsLibraryDirSet)
{
UpdateTagCompletionCsv();
}
return;
void UpdateTagCompletionCsv()
{
var csvPath = settingsManager.Settings.TagCompletionCsv;
if (csvPath is null) return;
if (csvPath is null)
return;
var fullPath = settingsManager.TagsDirectory.JoinFile(csvPath);
BackgroundLoadFromFile(fullPath);
}
}
private static string PrepareInsertionText_RemoveUnderscores(string text)
{
return text.Replace("_", " ");
@ -84,62 +94,70 @@ public class CompletionProvider : ICompletionProvider
/// <inheritdoc />
public void BackgroundLoadFromFile(FilePath path, bool recreate = false)
{
LoadFromFile(path, recreate).SafeFireAndForget(onException: exception =>
{
const string title = "Failed to load tag completion file";
Debug.Fail(title);
Logger.Warn(exception, title);
notificationService.Show(title + $" {path.Name}",
exception.Message, NotificationType.Error);
}, true);
LoadFromFile(path, recreate)
.SafeFireAndForget(
onException: exception =>
{
const string title = "Failed to load tag completion file";
Debug.Fail(title);
Logger.Warn(exception, title);
notificationService.Show(
title + $" {path.Name}",
exception.Message,
NotificationType.Error
);
},
true
);
}
/// <inheritdoc />
public async Task LoadFromFile(FilePath path, bool recreate = false)
{
using var _ = await loadLock.LockAsync();
// Get Blake3 hash of file
var hash = await FileHash.GetBlake3Async(path);
Logger.Trace("Loading tags from {Path} with Blake3 hash {Hash}", path, hash);
// Check for AppData/StabilityMatrix/Temp/Tags/<hash>/*.bin
var tempTagsDir = GlobalConfig.HomeDir.JoinDir("Temp", "Tags");
var hashDir = tempTagsDir.JoinDir(hash);
hashDir.Create();
var headerFile = hashDir.JoinFile("header.bin");
var indexFile = hashDir.JoinFile("index.bin");
entries.Clear();
var timer = Stopwatch.StartNew();
// If directory or any file is missing, rebuild the index
if (recreate || !(hashDir.Exists && headerFile.Exists && indexFile.Exists))
{
Logger.Debug("Creating new index for {Path}", hashDir);
await using var headerStream = headerFile.Info.OpenWrite();
await using var indexStream = indexFile.Info.OpenWrite();
var builder = new IndexBuilder(headerStream, indexStream);
// Parse csv
await using var csvStream = path.Info.OpenRead();
var parser = new TagCsvParser(csvStream);
await foreach (var entry in parser.ParseAsync())
{
if (string.IsNullOrWhiteSpace(entry.Name)) continue;
if (string.IsNullOrWhiteSpace(entry.Name))
continue;
// Add to index
builder.Add(entry.Name);
// Add to local dictionary
entries.Add(entry.Name, entry);
}
await Task.Run(builder.Build);
}
else
@ -149,82 +167,104 @@ public class CompletionProvider : ICompletionProvider
await using var csvStream = path.Info.OpenRead();
var parser = new TagCsvParser(csvStream);
await foreach (var entry in parser.ParseAsync())
{
if (string.IsNullOrWhiteSpace(entry.Name)) continue;
if (string.IsNullOrWhiteSpace(entry.Name))
continue;
// Add to local dictionary
entries.Add(entry.Name, entry);
}
}
searcher = new InMemoryIndexSearcher(headerFile, indexFile);
searcher.Init();
var elapsed = timer.Elapsed;
Logger.Info("Loaded {Count} tags for {Path} in {Time:F2}s", entries.Count, path.Name, elapsed.TotalSeconds);
Logger.Info(
"Loaded {Count} tags for {Path} in {Time:F2}s",
entries.Count,
path.Name,
elapsed.TotalSeconds
);
}
/// <inheritdoc />
public IEnumerable<ICompletionData> GetCompletions(string searchTerm, int itemsCount, bool suggest)
public IEnumerable<ICompletionData> GetCompletions(
TextCompletionRequest completionRequest,
int itemsCount,
bool suggest
)
{
return GetCompletionsImpl_Index(searchTerm, itemsCount, suggest);
}
if (completionRequest.Type == CompletionType.Tag)
{
return GetCompletionTags(completionRequest.Text, itemsCount, suggest);
}
private IEnumerable<ICompletionData> GetCompletionsImpl_Fuzzy(string searchTerm, int itemsCount, bool suggest)
{
var extracted = FuzzySharp.Process
.ExtractTop(searchTerm, entries.Keys);
var results = extracted
.Where(r => r.Score > 40)
.Select(r => r.Value)
.ToImmutableArray();
// No results
if (results.IsEmpty)
if (completionRequest.Type == CompletionType.ExtraNetwork)
{
Logger.Trace("No results for {Term}", searchTerm);
return Array.Empty<ICompletionData>();
return GetCompletionNetworks(
completionRequest.ExtraNetworkTypes,
completionRequest.Text,
itemsCount
);
}
Logger.Trace("Got {Count} results for {Term}", results.Length, searchTerm);
// Get entry for each result
throw new InvalidOperationException();
}
private IEnumerable<ICompletionData> GetCompletionNetworks(
PromptExtraNetworkType networkType,
string searchTerm,
int itemsCount
)
{
var folderTypes = Enum.GetValues(typeof(PromptExtraNetworkType))
.Cast<PromptExtraNetworkType>()
.Where(f => networkType.HasFlag(f))
.Select(network => network.ConvertTo<SharedFolderType>());
var completions = new List<ICompletionData>();
foreach (var item in results)
foreach (var folderType in folderTypes)
{
if (entries.TryGetValue(item, out var entry))
// Get from index service
if (modelIndexService.ModelIndex.TryGetValue(folderType, out var localModels))
{
var entryType = TagTypeExtensions.FromE621(entry.Type.GetValueOrDefault(-1));
completions.Add(new TagCompletionData(entry.Name!, entryType)
{
Priority = entry.Count ?? 0
});
var results =
from model in localModels
where model.FileName.StartsWith(searchTerm, StringComparison.OrdinalIgnoreCase)
select ModelCompletionData.FromLocalModel(model, networkType);
completions.AddRange(results.Take(itemsCount));
}
}
return completions;
}
private IEnumerable<ICompletionData> GetCompletionsImpl_Index(string searchTerm, int itemsCount, bool suggest)
private IEnumerable<ICompletionData> GetCompletionTags(
string searchTerm,
int itemsCount,
bool suggest
)
{
if (searcher is null)
{
throw new InvalidOperationException("Index is not loaded");
}
var timer = Stopwatch.StartNew();
var searchOptions = new SearchOptions
{
Term = searchTerm,
MaxItemCount = itemsCount,
SuggestWhenFoundStartsWith = suggest
};
var result = searcher.Search(searchOptions);
// No results
@ -233,16 +273,16 @@ public class CompletionProvider : ICompletionProvider
Logger.Trace("No results for {Term}", searchTerm);
return Array.Empty<ICompletionData>();
}
// Is null for some reason?
if (result.Items is null)
{
Logger.Warn("Got null results for {Term}", searchTerm);
return Array.Empty<ICompletionData>();
}
Logger.Trace("Got {Count} results for {Term}", result.Items.Length, searchTerm);
// Get entry for each result
var completions = new List<ICompletionData>();
foreach (var item in result.Items)
@ -253,10 +293,14 @@ public class CompletionProvider : ICompletionProvider
completions.Add(new TagCompletionData(entry.Name!, entryType));
}
}
timer.Stop();
Logger.Trace("Completions for {Term} took {Time:F2}ms", searchTerm, timer.Elapsed.TotalMilliseconds);
Logger.Trace(
"Completions for {Term} took {Time:F2}ms",
searchTerm,
timer.Elapsed.TotalMilliseconds
);
return completions;
}
}

2
StabilityMatrix.Avalonia/Models/TagCompletion/ICompletionProvider.cs

@ -32,7 +32,7 @@ public interface ICompletionProvider
/// Returns a list of completion items for the given text.
/// </summary>
public IEnumerable<ICompletionData> GetCompletions(
string searchTerm,
TextCompletionRequest completionRequest,
int itemsCount,
bool suggest
);

11
StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks;
using DynamicData.Binding;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Comfy;
@ -15,11 +16,11 @@ public interface IInferenceClientManager : IDisposable, INotifyPropertyChanged,
[MemberNotNullWhen(true, nameof(Client))]
bool IsConnected { get; }
IReadOnlyCollection<HybridModelFile>? Models { get; set; }
IReadOnlyCollection<HybridModelFile>? VaeModels { get; set; }
IReadOnlyCollection<ComfySampler>? Samplers { get; set; }
IReadOnlyCollection<ComfyUpscaler>? Upscalers { get; set; }
IObservableCollection<HybridModelFile> Models { get; }
IObservableCollection<HybridModelFile> VaeModels { get; }
IObservableCollection<ComfySampler> Samplers { get; }
IObservableCollection<ComfyUpscaler> Upscalers { get; }
Task ConnectAsync();

154
StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

@ -1,13 +1,18 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using CommunityToolkit.Mvvm.ComponentModel;
using DynamicData;
using DynamicData.Binding;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Avalonia.ViewModels.PackageManager;
using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Inference;
using StabilityMatrix.Core.Models;
@ -35,28 +40,61 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
[MemberNotNullWhen(true, nameof(Client))]
public bool IsConnected => Client is not null;
[ObservableProperty]
private IReadOnlyCollection<HybridModelFile>? models;
private readonly SourceCache<HybridModelFile, string> modelsSource = new(p => p.GetId());
[ObservableProperty]
private IReadOnlyCollection<HybridModelFile>? vaeModels;
[ObservableProperty]
private IReadOnlyCollection<ComfySampler>? samplers;
public IObservableCollection<HybridModelFile> Models { get; } =
new ObservableCollectionExtended<HybridModelFile>();
[ObservableProperty]
private IReadOnlyCollection<ComfyUpscaler>? upscalers;
private readonly SourceCache<HybridModelFile, string> vaeModelsSource = new(p => p.GetId());
private readonly SourceCache<HybridModelFile, string> vaeModelsDefaults = new(p => p.GetId());
public IObservableCollection<HybridModelFile> VaeModels { get; } =
new ObservableCollectionExtended<HybridModelFile>();
private readonly SourceCache<ComfySampler, string> samplersSource = new(p => p.Name);
public IObservableCollection<ComfySampler> Samplers { get; } =
new ObservableCollectionExtended<ComfySampler>();
private readonly SourceCache<ComfyUpscaler, string> modelUpscalersSource = new(p => p.Name);
private readonly SourceCache<ComfyUpscaler, string> latentUpscalersSource = new(p => p.Name);
public IObservableCollection<ComfyUpscaler> Upscalers { get; } =
new ObservableCollectionExtended<ComfyUpscaler>();
public InferenceClientManager(
ILogger<InferenceClientManager> logger,
ILogger<InferenceClientManager> logger,
IApiFactory apiFactory,
IModelIndexService modelIndexService)
IModelIndexService modelIndexService
)
{
this.logger = logger;
this.apiFactory = apiFactory;
this.modelIndexService = modelIndexService;
modelsSource.Connect().DeferUntilLoaded().Bind(Models).Subscribe();
vaeModelsDefaults.AddOrUpdate(HybridModelFile.Default);
ClearSharedProperties();
vaeModelsDefaults
.Connect()
.Or(vaeModelsSource.Connect())
.DeferUntilLoaded()
.Bind(VaeModels)
.Subscribe();
samplersSource.Connect().DeferUntilLoaded().Bind(Samplers).Subscribe();
latentUpscalersSource
.Connect()
.Or(modelUpscalersSource.Connect())
.DeferUntilLoaded()
.Bind(Upscalers)
.Subscribe();
ResetSharedProperties();
}
private async Task LoadSharedPropertiesAsync()
@ -64,63 +102,85 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
if (!IsConnected)
throw new InvalidOperationException("Client is not connected");
var modelNames = await Client.GetModelNamesAsync();
Models = modelNames?.Select(HybridModelFile.FromRemote).ToImmutableArray();
if (await Client.GetModelNamesAsync() is { } modelNames)
{
modelsSource.EditDiff(
modelNames.Select(HybridModelFile.FromRemote),
HybridModelFile.Comparer
);
}
// Fetch sampler names from KSampler node
var samplerNames = await Client.GetSamplerNamesAsync();
Samplers = samplerNames?.Select(name => new ComfySampler(name)).ToImmutableArray();
if (await Client.GetSamplerNamesAsync() is { } samplerNames)
{
samplersSource.EditDiff(
samplerNames.Select(name => new ComfySampler(name)),
ComfySampler.Comparer
);
}
// Upscalers is latent and esrgan combined
var upscalerBuilder = ImmutableArray.CreateBuilder<ComfyUpscaler>();
// Add latent upscale methods from LatentUpscale node
var latentUpscalerNames = await Client.GetNodeOptionNamesAsync(
"LatentUpscale",
"upscale_method");
if (latentUpscalerNames is not null)
if (
await Client.GetNodeOptionNamesAsync("LatentUpscale", "upscale_method") is
{ } latentUpscalerNames
)
{
upscalerBuilder.AddRange(latentUpscalerNames.Select(
s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)));
latentUpscalersSource.EditDiff(
latentUpscalerNames.Select(s => new ComfyUpscaler(s, ComfyUpscalerType.Latent)),
ComfyUpscaler.Comparer
);
logger.LogTrace("Loaded latent upscale methods: {@Upscalers}", latentUpscalerNames);
}
logger.LogTrace("Loaded latent upscale methods: {@Upscalers}", latentUpscalerNames);
// Add Model upscale methods
var modelUpscalerNames = await Client.GetNodeOptionNamesAsync(
"UpscaleModelLoader",
"model_name");
if (modelUpscalerNames is not null)
if (
await Client.GetNodeOptionNamesAsync("UpscaleModelLoader", "model_name") is
{ } modelUpscalerNames
)
{
upscalerBuilder.AddRange(modelUpscalerNames.Select(
s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)));
modelUpscalersSource.EditDiff(modelUpscalerNames.Select(
s => new ComfyUpscaler(s, ComfyUpscalerType.ESRGAN)), ComfyUpscaler.Comparer);
logger.LogTrace("Loaded model upscale methods: {@Upscalers}", modelUpscalerNames);
}
logger.LogTrace("Loaded model upscale methods: {@Upscalers}", modelUpscalerNames);
Upscalers = upscalerBuilder.ToImmutable();
}
/// <summary>
/// Clears shared properties and sets them to local defaults
/// </summary>
private void ClearSharedProperties()
private void ResetSharedProperties()
{
// Load local models
modelIndexService.GetModelsOfType(SharedFolderType.StableDiffusion)
modelIndexService
.GetModelsOfType(SharedFolderType.StableDiffusion)
.ContinueWith(task =>
{
Models = task.Result.Select(HybridModelFile.FromLocal).ToImmutableArray();
}).SafeFireAndForget();
modelsSource.EditDiff(
task.Result.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer
);
})
.SafeFireAndForget();
// Load local VAE models
modelIndexService.GetModelsOfType(SharedFolderType.VAE)
modelIndexService
.GetModelsOfType(SharedFolderType.VAE)
.ContinueWith(task =>
{
VaeModels = task.Result.Select(HybridModelFile.FromLocal).ToImmutableArray();
}).SafeFireAndForget();
Samplers = ComfySampler.Defaults;
Upscalers = null;
vaeModelsSource.EditDiff(
task.Result.Select(HybridModelFile.FromLocal),
HybridModelFile.Comparer
);
})
.SafeFireAndForget();
samplersSource.EditDiff(ComfySampler.Defaults, ComfySampler.Comparer);
latentUpscalersSource.EditDiff(ComfyUpscaler.Defaults, ComfyUpscaler.Comparer);
modelUpscalersSource.Clear();
}
public async Task ConnectAsync()
@ -164,7 +224,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
await Client.CloseAsync();
Client = null;
ClearSharedProperties();
ResetSharedProperties();
}
public void Dispose()

26
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -152,7 +152,25 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
}
);
var checkpointVae = checkpointLoader.GetOutput<VAENodeConnection>(2);
// Either use checkpoint VAE or custom VAE
VAENodeConnection vaeSource;
if (modelCard is {IsVaeSelectionEnabled: true, SelectedVae.IsDefault: false})
{
// Use custom VAE
// Add a loader
var vaeLoader =
prompt.AddNamedNode(ComfyNodeBuilder.VAELoader("VAELoader", modelCard.SelectedVae.FileName));
// Set as source
vaeSource = vaeLoader.Output;
}
else
{
// Use checkpoint VAE
vaeSource = checkpointLoader.GetOutput<VAENodeConnection>(2);
}
var emptyLatentImage = prompt.AddNamedNode(
new NamedComfyNode("EmptyLatentImage")
@ -219,7 +237,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
Inputs = new Dictionary<string, object?>
{
["samples"] = lastLatent,
["vae"] = checkpointLoader.GetOutput(2)
["vae"] = vaeSource
}
}
);
@ -260,7 +278,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
{
// Otherwise upscale the latent image
hiresLatent = builder.Group_UpscaleToLatent("HiresFix",
lastLatent, checkpointVae, selectedUpscaler, hiresWidth, hiresHeight).Output;
lastLatent, vaeSource, selectedUpscaler, hiresWidth, hiresHeight).Output;
}
var hiresSampler = prompt.AddNamedNode(
@ -301,7 +319,7 @@ public partial class InferenceTextToImageViewModel : InferenceTabViewModelBase
// Build group
var postUpscaleGroup = builder.Group_UpscaleToImage("PostUpscale",
lastLatent, checkpointVae, postUpscalerCard.SelectedUpscaler!.Value,
lastLatent, vaeSource, postUpscalerCard.SelectedUpscaler!.Value,
upscaleWidth, upscaleHeight);
// Remove the original vae decoder

Loading…
Cancel
Save