using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia;
using Avalonia.Controls;
using Avalonia.Controls.Notifications;
using Avalonia.Media.Imaging;
using Avalonia.Platform;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
using NLog;
using Octokit;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Dialogs;
using StabilityMatrix.Avalonia.Views.Dialogs;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Services;
using Notification = Avalonia.Controls.Notifications.Notification;

namespace StabilityMatrix.Avalonia.ViewModels;

/// <summary>
/// Card view model for a single CivitAI model in the checkpoint browser.
/// Shows a preview image and imports (downloads) the model into the models directory.
/// </summary>
public partial class CheckpointBrowserCardViewModel : ProgressViewModel
{
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
    private readonly IDownloadService downloadService;
    private readonly ISettingsManager settingsManager;
    private readonly ServiceManager<ViewModelBase> dialogFactory;
    private readonly INotificationService notificationService;

    /// <summary>The CivitAI model displayed by this card.</summary>
    public CivitModel CivitModel { get; init; }

    /// <summary>Progress text is only visible once some progress has been reported.</summary>
    public override bool IsTextVisible => Value > 0;

    [ObservableProperty]
    private Bitmap? cardImage;

    [ObservableProperty]
    private bool isImporting;

    /// <param name="civitModel">Model shown by this card.</param>
    /// <param name="downloadService">Used to fetch the preview image and download model files.</param>
    /// <param name="settingsManager">Provides the NSFW setting and the models directory.</param>
    /// <param name="dialogFactory">Resolves the version-selection dialog view model.</param>
    /// <param name="notificationService">Shows user-facing notifications.</param>
    /// <param name="fixedImage">
    /// When provided, used as the card image; no remote image is loaded and no
    /// NSFW-setting change handler is registered.
    /// </param>
    public CheckpointBrowserCardViewModel(
        CivitModel civitModel,
        IDownloadService downloadService,
        ISettingsManager settingsManager,
        ServiceManager<ViewModelBase> dialogFactory,
        INotificationService notificationService,
        Bitmap? fixedImage = null)
    {
        this.downloadService = downloadService;
        this.settingsManager = settingsManager;
        this.dialogFactory = dialogFactory;
        this.notificationService = notificationService;
        CivitModel = civitModel;

        if (fixedImage != null)
        {
            CardImage = fixedImage;
            return;
        }

        UpdateImage().SafeFireAndForget();

        // Update image when nsfw setting changes.
        // NOTE(review): no matching unregistration is visible in this class — confirm the
        // settings manager does not keep this view model alive after the card is discarded.
        settingsManager.RegisterPropertyChangedHandler(
            s => s.ModelBrowserNsfwEnabled,
            _ => UpdateImage().SafeFireAndForget());
    }

    /// <summary>
    /// Loads the card image. It uses the first image of the model's first version that the NSFW
    /// setting allows (an image passes when NSFW is enabled or its Nsfw value is "None").
    /// When no such image exists, it falls back to the bundled "noimage" asset.
    /// </summary>
    private async Task UpdateImage()
    {
        var nsfwEnabled = settingsManager.Settings.ModelBrowserNsfwEnabled;
        var version = CivitModel.ModelVersions?.FirstOrDefault();
        var images = version?.Images;

        var image = images?.FirstOrDefault(img => nsfwEnabled || img.Nsfw == "None");
        if (image != null)
        {
            var imageStream = await downloadService.GetImageStreamFromUrl(image.Url);
            Dispatcher.UIThread.Post(() =>
            {
                // Bitmap reads the stream when constructed, so release it afterwards
                using (imageStream)
                {
                    CardImage = new Bitmap(imageStream);
                }
            });
            return;
        }

        // Otherwise Default image
        var assetStream = AssetLoader.Open(new Uri("avares://StabilityMatrix.Avalonia/Assets/noimage.png"));
        Dispatcher.UIThread.Post(() =>
        {
            using (assetStream)
            {
                CardImage = new Bitmap(assetStream);
            }
        });
    }

    // On any mode changes, update the image
    private void OnNsfwModeChanged(object? sender, bool value)
    {
        UpdateImage().SafeFireAndForget();
    }

    /// <summary>Opens this card's model page on civitai.com.</summary>
    [RelayCommand]
    private void OpenModel()
    {
        ProcessRunner.OpenUrl($"https://civitai.com/models/{CivitModel.Id}");
    }

    /// <summary>Imports the model's default version and file.</summary>
    [RelayCommand]
    private async Task Import(CivitModel model)
    {
        await DoImport(model);
    }

    /// <summary>
    /// Shows a dialog for picking a model version and file, then imports the selection.
    /// It does nothing further when the dialog is not closed with the primary result.
    /// </summary>
    [RelayCommand]
    private async Task ShowVersionDialog(CivitModel model)
    {
        var versions = model.ModelVersions;
        if (versions is null || versions.Count == 0)
        {
            notificationService.Show(new Notification("Model has no versions available",
                "This model has no versions available for download", NotificationType.Warning));
            return;
        }

        var dialog = new ContentDialog
        {
            Title = model.Name,
            IsPrimaryButtonEnabled = false,
            IsSecondaryButtonEnabled = false,
        };

        var viewModel = dialogFactory.Get<SelectModelVersionViewModel>();
        viewModel.Dialog = dialog;
        viewModel.Versions = versions;
        dialog.Content = new SelectModelVersionDialog
        {
            DataContext = viewModel
        };

        var result = await dialog.ShowAsync();
        if (result != ContentDialogResult.Primary)
        {
            return;
        }

        var selectedVersion = viewModel.SelectedVersion;
        var selectedFile = viewModel.SelectedFile;

        // NOTE(review): short pause before importing; the reason is not visible here
        // (presumably it lets the dialog finish closing) — confirm.
        await Task.Delay(100);
        await DoImport(model, selectedVersion, selectedFile);
    }

    /// <summary>
    /// Imports a model in four steps:
    /// 1. Downloads the model file into the models sub-folder for the model's type.
    /// 2. Validates the file against the expected SHA256 hash, when one is provided.
    /// 3. Saves the connected model info next to the file.
    /// 4. Saves a jpg/jpeg/png preview image next to the file, when one is available.
    /// Files created by an import that does not complete are deleted again.
    /// </summary>
    /// <param name="model">Model to import.</param>
    /// <param name="selectedVersion">Version to import; defaults to the model's first version.</param>
    /// <param name="selectedFile">File to import; defaults to the version's first file of type Model.</param>
    private async Task DoImport(CivitModel model, CivitModelVersion? selectedVersion = null,
        CivitFile? selectedFile = null)
    {
        IsImporting = true;
        Text = "Downloading...";

        // Holds files to be deleted on errors
        var filesForCleanup = new HashSet<FilePath>();

        // Set Text when exiting, finally block will set 100 and delay clear progress
        try
        {
            // Get latest version
            var modelVersion = selectedVersion ?? model.ModelVersions?.FirstOrDefault();
            if (modelVersion is null)
            {
                notificationService.Show(new Notification("Model has no versions available",
                    "This model has no versions available for download", NotificationType.Warning));
                Text = "Unable to Download";
                return;
            }

            // Get latest version file
            var modelFile = selectedFile ?? modelVersion.Files?.FirstOrDefault(x => x.Type == CivitFileType.Model);
            if (modelFile is null)
            {
                notificationService.Show(new Notification("Model has no files available",
                    "This model has no files available for download", NotificationType.Warning));
                Text = "Unable to Download";
                return;
            }

            var downloadFolder = Path.Combine(settingsManager.ModelsDirectory,
                model.Type.ConvertTo<SharedFolderType>().GetStringValue());
            // Folders might be missing if user didn't install any packages yet
            Directory.CreateDirectory(downloadFolder);
            var downloadPath = Path.GetFullPath(Path.Combine(downloadFolder, modelFile.Name));
            filesForCleanup.Add(downloadPath);

            // Do the download
            var downloadTask = downloadService.DownloadToFileAsync(modelFile.DownloadUrl, downloadPath,
                new Progress<ProgressReport>(report =>
                {
                    // Only push updates when the percentage moved noticeably
                    if (Math.Abs(report.Percentage - Value) > 0.1)
                    {
                        Dispatcher.UIThread.Post(() =>
                        {
                            Value = report.Percentage;
                            Text = $"Downloading... {report.Percentage}%";
                        });
                    }
                }));

            var downloadResult = await notificationService.TryAsync(downloadTask, "Could not download file");

            // Failed download handling
            if (downloadResult.Exception is not null)
            {
                // Expected network/API/cancellation failures are warnings; anything else is an error
                var logLevel = downloadResult.Exception switch
                {
                    HttpRequestException or ApiException or TaskCanceledException => LogLevel.Warn,
                    _ => LogLevel.Error
                };
                Logger.Log(logLevel, downloadResult.Exception, "Error during model download");

                Text = "Download Failed";
                return;
            }

            // When sha256 is available, validate the downloaded file
            var fileExpectedSha256 = modelFile.Hashes.SHA256;
            if (!string.IsNullOrEmpty(fileExpectedSha256))
            {
                var hashProgress = new Progress<ProgressReport>(progress =>
                {
                    Value = progress.Percentage;
                    Text = $"Validating... {progress.Percentage}%";
                });
                var sha256 = await FileHash.GetSha256Async(downloadPath, hashProgress);
                // Hex digests may differ only in letter case
                if (!string.Equals(sha256, fileExpectedSha256, StringComparison.OrdinalIgnoreCase))
                {
                    notificationService.Show(new Notification("Download failed hash validation",
                        "This may be caused by network or server issues from CivitAI, please try again in a few minutes.",
                        NotificationType.Error));
                    Text = "Download Failed";
                    return;
                }
            }

            IsIndeterminate = true;

            // Save connected model info
            var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Name);
            var modelInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTime.UtcNow);
            var modelInfoPath = Path.GetFullPath(Path.Combine(
                downloadFolder, modelFileName + ConnectedModelInfo.FileExtension));
            filesForCleanup.Add(modelInfoPath);
            await modelInfo.SaveJsonToDirectory(downloadFolder, modelFileName);

            // If available, save a model image
            if (modelVersion.Images != null && modelVersion.Images.Any())
            {
                var image = modelVersion.Images[0];
                var imageExtension = Path.GetExtension(image.Url).TrimStart('.');
                if (imageExtension is "jpg" or "jpeg" or "png")
                {
                    var imageDownloadPath = Path.GetFullPath(Path.Combine(downloadFolder,
                        $"{modelFileName}.preview.{imageExtension}"));
                    filesForCleanup.Add(imageDownloadPath);
                    var imageTask = downloadService.DownloadToFileAsync(image.Url, imageDownloadPath);
                    await notificationService.TryAsync(imageTask, "Could not download preview image");
                }
            }

            // Successful - clear cleanup list
            filesForCleanup.Clear();

            // Notify only once everything (including metadata) has been written
            notificationService.Show(new Notification("Import complete",
                $"{model.Type} {model.Name} imported successfully!", NotificationType.Success));
            Text = "Import complete!";
        }
        catch (Exception e)
        {
            // Do not swallow unexpected failures: log them and reflect the failure in the UI
            Logger.Error(e, "Error during model import");
            Text = "Download Failed";
        }
        finally
        {
            foreach (var file in filesForCleanup.Where(file => file.Exists))
            {
                file.Delete();
                Logger.Info($"Download cleanup: Deleted file {file}");
            }

            IsIndeterminate = false;
            Value = 100;
            DelayedClearProgress(TimeSpan.FromMilliseconds(800));
        }
    }

    /// <summary>
    /// After <paramref name="delay"/>, clears the progress text and value and resets
    /// <see cref="IsImporting"/>.
    /// </summary>
    private void DelayedClearProgress(TimeSpan delay)
    {
        Task.Delay(delay).ContinueWith(_ =>
        {
            // The continuation may run off the UI thread; apply bound-property changes on the UI thread
            Dispatcher.UIThread.Post(() =>
            {
                Text = string.Empty;
                Value = 0;
                IsImporting = false;
            });
        });
    }
}