From 6d5f47ddaa42e19295c000d2016117f7c4ad5c68 Mon Sep 17 00:00:00 2001 From: Ionite Date: Mon, 5 Jun 2023 04:11:59 -0400 Subject: [PATCH] Download validation fixes --- .../MockCheckpointBrowserViewModel.cs | 4 +- StabilityMatrix/Services/DownloadService.cs | 15 ++++--- .../CheckpointBrowserCardViewModel.cs | 42 +++++++++++++++++-- .../ViewModels/CheckpointBrowserViewModel.cs | 10 ++--- 4 files changed, 54 insertions(+), 17 deletions(-) diff --git a/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs b/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs index fdaccfed..94940ebf 100644 --- a/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs +++ b/StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs @@ -8,11 +8,11 @@ namespace StabilityMatrix.DesignData; [DesignOnly(true)] public class MockCheckpointBrowserViewModel : CheckpointBrowserViewModel { - public MockCheckpointBrowserViewModel() : base(null!, null!) + public MockCheckpointBrowserViewModel() : base(null!, null!, null!) { ModelCards = new ObservableCollection { - new (null!, null!) + new (null!, null!, null!) { CivitModel = new() { diff --git a/StabilityMatrix/Services/DownloadService.cs b/StabilityMatrix/Services/DownloadService.cs index 6669de51..395d8f53 100644 --- a/StabilityMatrix/Services/DownloadService.cs +++ b/StabilityMatrix/Services/DownloadService.cs @@ -5,6 +5,7 @@ using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using Polly.Contrib.WaitAndRetry; using StabilityMatrix.Models; namespace StabilityMatrix.Services; @@ -29,26 +30,28 @@ public class DownloadService : IDownloadService await using var file = new FileStream(downloadLocation, FileMode.Create, FileAccess.Write, FileShare.None); long contentLength = 0; - var retryCount = 0; - + var response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); contentLength = response.Content.Headers.ContentLength ?? 0; - while (contentLength == 0 && retryCount++ < 5) + var delays = Backoff.DecorrelatedJitterBackoffV2( + TimeSpan.FromMilliseconds(50), retryCount: 3); + + foreach (var delay in delays) { + if (contentLength > 0) break; logger.LogDebug("Retrying get-headers for content-length"); - Thread.Sleep(50); + await Task.Delay(delay); response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); contentLength = response.Content.Headers.ContentLength ?? 0; } - var isIndeterminate = contentLength == 0; await using var stream = await response.Content.ReadAsStreamAsync(); var totalBytesRead = 0L; + var buffer = new byte[bufferSize]; while (true) { - var buffer = new byte[bufferSize]; var bytesRead = await stream.ReadAsync(buffer); if (bytesRead == 0) break; await file.WriteAsync(buffer.AsMemory(0, bytesRead)); diff --git a/StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs b/StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs index c14cd63e..d07d99dc 100644 --- a/StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs +++ b/StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs @@ -1,10 +1,11 @@ using System; using System.Diagnostics; using System.IO; -using System.Net.Mime; using System.Threading.Tasks; using System.Windows; using CommunityToolkit.Mvvm.Input; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Helper; using StabilityMatrix.Models; using StabilityMatrix.Models.Api; using StabilityMatrix.Services; @@ -14,14 +15,16 @@ namespace StabilityMatrix.ViewModels; public partial class CheckpointBrowserCardViewModel : ProgressViewModel { private readonly IDownloadService downloadService; + private readonly ISnackbarService snackbarService; public CivitModel CivitModel { get; init; } public override Visibility ProgressVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed; public override Visibility TextVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed; - public CheckpointBrowserCardViewModel(CivitModel civitModel, IDownloadService downloadService) + public CheckpointBrowserCardViewModel(CivitModel civitModel, IDownloadService downloadService, ISnackbarService snackbarService) { this.downloadService = downloadService; + this.snackbarService = snackbarService; CivitModel = civitModel; } @@ -41,18 +44,49 @@ public partial class CheckpointBrowserCardViewModel : ProgressViewModel Text = "Downloading..."; var latestModelFile = model.ModelVersions[0].Files[0]; + var fileExpectedSha256 = latestModelFile.Hashes.SHA256; var downloadPath = Path.Combine(SharedFolders.SharedFoldersPath, SharedFolders.SharedFolderTypeToName(model.Type.ToSharedFolderType()), latestModelFile.Name); - var progress = new Progress(progress => + var downloadProgress = new Progress(progress => { Value = progress.Percentage; Text = $"Importing... {progress.Percentage}%"; }); - await downloadService.DownloadToFileAsync(latestModelFile.DownloadUrl, downloadPath, progress: progress); + await downloadService.DownloadToFileAsync(latestModelFile.DownloadUrl, downloadPath, progress: downloadProgress); + + // When sha256 is available, validate the downloaded file + if (!string.IsNullOrEmpty(fileExpectedSha256)) + { + var hashProgress = new Progress(progress => + { + Value = progress.Percentage; + Text = $"Validating... {progress.Percentage}%"; + }); + var sha256 = await FileHash.GetSha256Async(downloadPath, hashProgress); + if (sha256 != fileExpectedSha256.ToLowerInvariant()) + { + Text = "Import Failed!"; + DelayedClearProgress(TimeSpan.FromSeconds(800)); + await snackbarService.ShowSnackbarAsync( + "This may be caused by network or server issues from CivitAI, please try again in a few minutes.", + "Download failed hash validation", LogLevel.Warning); + return; + } + } Text = "Import complete!"; Value = 100; + DelayedClearProgress(TimeSpan.FromMilliseconds(800)); + } + + private void DelayedClearProgress(TimeSpan delay) + { + Task.Delay(delay).ContinueWith(_ => + { + Text = string.Empty; + Value = 0; + }); } } diff --git a/StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs b/StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs index 91487f11..1dd16896 100644 --- a/StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs +++ b/StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs @@ -1,15 +1,13 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; -using System.Diagnostics; -using System.IO; using System.Linq; using System.Threading.Tasks; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using NLog; using StabilityMatrix.Api; -using StabilityMatrix.Models; +using StabilityMatrix.Helper; using StabilityMatrix.Models.Api; using StabilityMatrix.Services; @@ -20,6 +18,7 @@ public partial class CheckpointBrowserViewModel : ObservableObject private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private readonly ICivitApi civitApi; private readonly IDownloadService downloadService; + private readonly ISnackbarService snackbarService; private const int MaxModelsPerPage = 14; [ObservableProperty] private string? searchQuery; @@ -38,10 +37,11 @@ public partial class CheckpointBrowserViewModel : ObservableObject public IEnumerable AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast(); public IEnumerable AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast(); - public CheckpointBrowserViewModel(ICivitApi civitApi, IDownloadService downloadService) + public CheckpointBrowserViewModel(ICivitApi civitApi, IDownloadService downloadService, ISnackbarService snackbarService) { this.civitApi = civitApi; this.downloadService = downloadService; + this.snackbarService = snackbarService; SelectedPeriod = CivitPeriod.Month; SortMode = CivitSortMode.HighestRated; @@ -76,7 +76,7 @@ public partial class CheckpointBrowserViewModel : ObservableObject CanGoToPreviousPage = CurrentPageNumber > 1; CanGoToNextPage = CurrentPageNumber < TotalPages; ModelCards = new ObservableCollection(models.Items.Select( - m => new CheckpointBrowserCardViewModel(m, downloadService))); + m => new CheckpointBrowserCardViewModel(m, downloadService, snackbarService))); ShowMainLoadingSpinner = false; Logger.Debug($"Found {models.Items.Length} models");