using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using DynamicData.Binding;
using NLog;
using StabilityMatrix.Avalonia.Extensions;
using StabilityMatrix.Avalonia.Models;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Services;
using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView;

#pragma warning disable CS0657 // Not a valid attribute location for this declaration

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

// NOTE(review): in the text this file was reviewed from, several generic
// invocations and type names appear without type arguments (vmFactory.Get(),
// GetCard(), OfType(), AddModule(), GetPropertyValueOrDefault(), new List(),
// ServiceManager, IEnumerable). They look like they were lost in extraction;
// TODO confirm against version control and restore them.

/// <summary>
/// View model for the text-to-image inference tab. Owns the model, sampler,
/// prompt, batch size, seed and module cards, builds the generation prompt
/// from them, and runs one generation per configured batch.
/// </summary>
[View(typeof(InferenceTextToImageView), IsPersistent = true)]
[ManagedService]
[Transient]
public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, IParametersLoadableState
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    private readonly INotificationService notificationService;
    private readonly IModelIndexService modelIndexService;

    /// <summary>Container card holding the cards added in the constructor. Not serialized.</summary>
    [JsonIgnore]
    public StackCardViewModel StackCardViewModel { get; }

    /// <summary>Editable stack of module cards; serialized as "Modules".</summary>
    [JsonPropertyName("Modules")]
    public StackEditableCardViewModel ModulesCardViewModel { get; }

    /// <summary>Model selection card; serialized as "Model".</summary>
    [JsonPropertyName("Model")]
    public ModelCardViewModel ModelCardViewModel { get; }

    /// <summary>Sampler settings card; serialized as "Sampler".</summary>
    [JsonPropertyName("Sampler")]
    public SamplerCardViewModel SamplerCardViewModel { get; }

    /// <summary>Prompt card; serialized as "Prompt".</summary>
    [JsonPropertyName("Prompt")]
    public PromptCardViewModel PromptCardViewModel { get; }

    /// <summary>Batch size / batch count card; serialized as "BatchSize".</summary>
    [JsonPropertyName("BatchSize")]
    public BatchSizeCardViewModel BatchSizeCardViewModel { get; }

    /// <summary>Seed card; serialized as "Seed".</summary>
    [JsonPropertyName("Seed")]
    public SeedCardViewModel SeedCardViewModel { get; }

    /// <summary>
    /// Creates the view model, resolves all sub view models from
    /// <paramref name="vmFactory"/>, and wires the refiner-steps toggle on the
    /// sampler card to the model card.
    /// </summary>
    public InferenceTextToImageViewModel(
        INotificationService notificationService,
        IInferenceClientManager inferenceClientManager,
        ISettingsManager settingsManager,
        ServiceManager vmFactory,
        IModelIndexService modelIndexService,
        RunningPackageService runningPackageService
    )
        : base(vmFactory, inferenceClientManager, notificationService, settingsManager, runningPackageService)
    {
        this.notificationService = notificationService;
        this.modelIndexService = modelIndexService;

        // Get sub view models from service manager

        SeedCardViewModel = vmFactory.Get();
        SeedCardViewModel.GenerateNewSeed();

        ModelCardViewModel = vmFactory.Get();

        SamplerCardViewModel = vmFactory.Get(samplerCard =>
        {
            samplerCard.IsDimensionsEnabled = true;
            samplerCard.IsCfgScaleEnabled = true;
            samplerCard.IsSamplerSelectionEnabled = true;
            samplerCard.IsSchedulerSelectionEnabled = true;
            samplerCard.DenoiseStrength = 1.0d;
        });

        PromptCardViewModel = vmFactory.Get();
        BatchSizeCardViewModel = vmFactory.Get();

        ModulesCardViewModel = vmFactory.Get(modulesCard =>
        {
            modulesCard.AvailableModules = new[]
            {
                typeof(HiresFixModule),
                typeof(UpscalerModule),
                typeof(SaveImageModule)
            };
            modulesCard.DefaultModules = new[] { typeof(HiresFixModule), typeof(UpscalerModule) };
            modulesCard.InitializeDefaults();
        });

        StackCardViewModel = vmFactory.Get();
        StackCardViewModel.AddCards(
            ModelCardViewModel,
            SamplerCardViewModel,
            ModulesCardViewModel,
            SeedCardViewModel,
            BatchSizeCardViewModel
        );

        // When refiner is provided in model card, enable for sampler
        ModelCardViewModel
            .WhenPropertyChanged(x => x.IsRefinerSelectionEnabled)
            .Subscribe(e =>
            {
                SamplerCardViewModel.IsRefinerStepsEnabled = e.Sender
                    is { IsRefinerSelectionEnabled: true, SelectedRefiner: not null };
            });
    }

    /// <summary>
    /// Populates the prompt builder in <paramref name="args"/>: seed, batch
    /// size, model, empty latent source, prompts, sampler, every module card,
    /// pre-output actions, and finally the output image.
    /// </summary>
    /// <param name="args">Build arguments; a non-null SeedOverride takes precedence over the seed card.</param>
    protected override void BuildPrompt(BuildPromptEventArgs args)
    {
        base.BuildPrompt(args);

        var builder = args.Builder;

        // Load constants
        builder.Connections.Seed = args.SeedOverride switch
        {
            { } seed => Convert.ToUInt64(seed),
            _ => Convert.ToUInt64(SeedCardViewModel.Seed)
        };

        var applyArgs = args.ToModuleApplyStepEventArgs();

        BatchSizeCardViewModel.ApplyStep(applyArgs);

        // Load models
        ModelCardViewModel.ApplyStep(applyArgs);

        // Setup empty latent; the batch index is only passed when enabled on the card
        builder.SetupEmptyLatentSource(
            SamplerCardViewModel.Width,
            SamplerCardViewModel.Height,
            BatchSizeCardViewModel.BatchSize,
            BatchSizeCardViewModel.IsBatchIndexEnabled ? BatchSizeCardViewModel.BatchIndex : null
        );

        // Prompts and loras
        PromptCardViewModel.ApplyStep(applyArgs);

        // Setup Sampler and Refiner if enabled
        SamplerCardViewModel.ApplyStep(applyArgs);

        // Apply each module card (hires fix, upscaler, ...) in the order it appears in Cards
        foreach (var module in ModulesCardViewModel.Cards.OfType())
        {
            module.ApplyStep(applyArgs);
        }

        applyArgs.InvokeAllPreOutputActions();

        builder.SetupOutputImage();
    }

    /// <summary>
    /// Collects input images from the sampler card's modules followed by the
    /// top-level module cards.
    /// </summary>
    /// <returns>Sampler-module images concatenated with module-card images.</returns>
    protected override IEnumerable GetInputImages()
    {
        var samplerImages = SamplerCardViewModel
            .ModulesCardViewModel.Cards.OfType()
            .SelectMany(m => m.GetInputImages());

        var moduleImages = ModulesCardViewModel
            .Cards.OfType()
            .SelectMany(m => m.GetInputImages());

        return samplerImages.Concat(moduleImages);
    }

    /// <summary>
    /// Validates prompts, model and client connection, optionally randomizes
    /// the seed, builds one set of generation arguments per batch (seed
    /// incremented by the batch index), then runs the batches sequentially.
    /// Returns without generating if any validation step fails.
    /// </summary>
    /// <param name="overrides">Generation overrides; UseCurrentSeed suppresses seed randomization.</param>
    /// <param name="cancellationToken">Token forwarded to each RunGeneration call.</param>
    protected override async Task GenerateImageImpl(
        GenerateOverrides overrides,
        CancellationToken cancellationToken
    )
    {
        // Validate the prompts
        if (!await PromptCardViewModel.ValidatePrompts())
            return;

        if (!await ModelCardViewModel.ValidateModel())
            return;

        if (!await CheckClientConnectedWithPrompt() || !ClientManager.IsConnected)
            return;

        // If enabled, randomize the seed
        var seedCard = StackCardViewModel.GetCard();
        if (overrides is not { UseCurrentSeed: true } && seedCard.IsRandomizeEnabled)
        {
            seedCard.GenerateNewSeed();
        }

        var batches = BatchSizeCardViewModel.BatchCount;

        var batchArgs = new List();

        // All prompts are built up front, before the first batch starts running
        for (var i = 0; i < batches; i++)
        {
            var seed = seedCard.Seed + i;

            var buildPromptArgs = new BuildPromptEventArgs { Overrides = overrides, SeedOverride = seed };
            BuildPrompt(buildPromptArgs);

            var generationArgs = new ImageGenerationEventArgs
            {
                Client = ClientManager.Client,
                Nodes = buildPromptArgs.Builder.ToNodeDictionary(),
                OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
                Parameters = SaveStateToParameters(new GenerationParameters()),
                Project = InferenceProjectDocument.FromLoadable(this),
                FilesToTransfer = buildPromptArgs.FilesToTransfer,
                BatchIndex = i,
                // Only clear output images on the first batch
                ClearOutputImages = i == 0
            };

            batchArgs.Add(generationArgs);
        }

        // Run batches
        foreach (var args in batchArgs)
        {
            await RunGeneration(args, cancellationToken);
        }
    }

    /// <summary>
    /// Loads prompt, sampler, model and seed state from
    /// <paramref name="parameters"/>, then resets the sampler denoise strength
    /// to 1.0 when it differs from 1.0 by more than 0.01.
    /// </summary>
    public void LoadStateFromParameters(GenerationParameters parameters)
    {
        PromptCardViewModel.LoadStateFromParameters(parameters);
        SamplerCardViewModel.LoadStateFromParameters(parameters);
        ModelCardViewModel.LoadStateFromParameters(parameters);

        // NOTE(review): if parameters.Seed is an unsigned 64-bit value, Convert.ToInt64
        // throws OverflowException for seeds above long.MaxValue — confirm whether such
        // seeds can reach this method.
        SeedCardViewModel.Seed = Convert.ToInt64(parameters.Seed);

        if (Math.Abs(SamplerCardViewModel.DenoiseStrength - 1.0d) > 0.01d)
        {
            SamplerCardViewModel.DenoiseStrength = 1.0d;
        }
    }

    /// <summary>
    /// Writes prompt, sampler, model and seed state into
    /// <paramref name="parameters"/> and returns the resulting instance.
    /// </summary>
    public GenerationParameters SaveStateToParameters(GenerationParameters parameters)
    {
        parameters = PromptCardViewModel.SaveStateToParameters(parameters);
        parameters = SamplerCardViewModel.SaveStateToParameters(parameters);
        parameters = ModelCardViewModel.SaveStateToParameters(parameters);

        parameters.Seed = (ulong)SeedCardViewModel.Seed;

        return parameters;
    }

    // Deserialization overrides

    /// <summary>
    /// Loads state from <paramref name="state"/>. For <paramref name="version"/>
    /// 2 and below, first rebuilds the module cards from the legacy top-level
    /// properties (IsHiresFixEnabled, HiresSampler, HiresUpscaler,
    /// IsUpscaleEnabled, Upscaler, IsFreeUEnabled) before delegating to the base
    /// implementation.
    /// </summary>
    public override void LoadStateFromJsonObject(JsonObject state, int version)
    {
        // For v2 and below, do migration
        if (version <= 2)
        {
            ModulesCardViewModel.Clear();

            // Add by default the original cards as steps - HiresFix, Upscaler
            ModulesCardViewModel.AddModule(module =>
            {
                module.IsEnabled = state.GetPropertyValueOrDefault("IsHiresFixEnabled");

                // TryGetPropertyValue also succeeds for an explicit JSON null, so only
                // load when the value really is an object; otherwise keep card defaults.
                if (
                    state.TryGetPropertyValue("HiresSampler", out var hiresSamplerNode)
                    && hiresSamplerNode is JsonObject hiresSamplerState
                )
                {
                    module.GetCard().LoadStateFromJsonObject(hiresSamplerState);
                }

                if (
                    state.TryGetPropertyValue("HiresUpscaler", out var hiresUpscalerNode)
                    && hiresUpscalerNode is JsonObject hiresUpscalerState
                )
                {
                    module.GetCard().LoadStateFromJsonObject(hiresUpscalerState);
                }
            });

            ModulesCardViewModel.AddModule(module =>
            {
                module.IsEnabled = state.GetPropertyValueOrDefault("IsUpscaleEnabled");

                if (
                    state.TryGetPropertyValue("Upscaler", out var upscalerNode)
                    && upscalerNode is JsonObject upscalerState
                )
                {
                    module.GetCard().LoadStateFromJsonObject(upscalerState);
                }
            });

            // Add FreeU to sampler
            SamplerCardViewModel.ModulesCardViewModel.AddModule(module =>
            {
                module.IsEnabled = state.GetPropertyValueOrDefault("IsFreeUEnabled");
            });
        }

        base.LoadStateFromJsonObject(state);
    }
}