Browse Source

Download validation fixes

pull/5/head
Ionite 1 year ago
parent
commit
6d5f47ddaa
No known key found for this signature in database
  1. 4
      StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs
  2. 15
      StabilityMatrix/Services/DownloadService.cs
  3. 42
      StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs
  4. 10
      StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs

4
StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs

@ -8,11 +8,11 @@ namespace StabilityMatrix.DesignData;
[DesignOnly(true)] [DesignOnly(true)]
public class MockCheckpointBrowserViewModel : CheckpointBrowserViewModel public class MockCheckpointBrowserViewModel : CheckpointBrowserViewModel
{ {
public MockCheckpointBrowserViewModel() : base(null!, null!) public MockCheckpointBrowserViewModel() : base(null!, null!, null!)
{ {
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel> ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>
{ {
new (null!, null!) new (null!, null!, null!)
{ {
CivitModel = new() CivitModel = new()
{ {

15
StabilityMatrix/Services/DownloadService.cs

@ -5,6 +5,7 @@ using System.Net.Http.Headers;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Polly.Contrib.WaitAndRetry;
using StabilityMatrix.Models; using StabilityMatrix.Models;
namespace StabilityMatrix.Services; namespace StabilityMatrix.Services;
@ -29,26 +30,28 @@ public class DownloadService : IDownloadService
await using var file = new FileStream(downloadLocation, FileMode.Create, FileAccess.Write, FileShare.None); await using var file = new FileStream(downloadLocation, FileMode.Create, FileAccess.Write, FileShare.None);
long contentLength = 0; long contentLength = 0;
var retryCount = 0;
var response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); var response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead);
contentLength = response.Content.Headers.ContentLength ?? 0; contentLength = response.Content.Headers.ContentLength ?? 0;
while (contentLength == 0 && retryCount++ < 5) var delays = Backoff.DecorrelatedJitterBackoffV2(
TimeSpan.FromMilliseconds(50), retryCount: 3);
foreach (var delay in delays)
{ {
if (contentLength > 0) break;
logger.LogDebug("Retrying get-headers for content-length"); logger.LogDebug("Retrying get-headers for content-length");
Thread.Sleep(50); await Task.Delay(delay);
response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead);
contentLength = response.Content.Headers.ContentLength ?? 0; contentLength = response.Content.Headers.ContentLength ?? 0;
} }
var isIndeterminate = contentLength == 0; var isIndeterminate = contentLength == 0;
await using var stream = await response.Content.ReadAsStreamAsync(); await using var stream = await response.Content.ReadAsStreamAsync();
var totalBytesRead = 0L; var totalBytesRead = 0L;
var buffer = new byte[bufferSize];
while (true) while (true)
{ {
var buffer = new byte[bufferSize];
var bytesRead = await stream.ReadAsync(buffer); var bytesRead = await stream.ReadAsync(buffer);
if (bytesRead == 0) break; if (bytesRead == 0) break;
await file.WriteAsync(buffer.AsMemory(0, bytesRead)); await file.WriteAsync(buffer.AsMemory(0, bytesRead));

42
StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs

@ -1,10 +1,11 @@
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Net.Mime;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Windows; using System.Windows;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Helper;
using StabilityMatrix.Models; using StabilityMatrix.Models;
using StabilityMatrix.Models.Api; using StabilityMatrix.Models.Api;
using StabilityMatrix.Services; using StabilityMatrix.Services;
@ -14,14 +15,16 @@ namespace StabilityMatrix.ViewModels;
public partial class CheckpointBrowserCardViewModel : ProgressViewModel public partial class CheckpointBrowserCardViewModel : ProgressViewModel
{ {
private readonly IDownloadService downloadService; private readonly IDownloadService downloadService;
private readonly ISnackbarService snackbarService;
public CivitModel CivitModel { get; init; } public CivitModel CivitModel { get; init; }
public override Visibility ProgressVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed; public override Visibility ProgressVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed;
public override Visibility TextVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed; public override Visibility TextVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed;
public CheckpointBrowserCardViewModel(CivitModel civitModel, IDownloadService downloadService) public CheckpointBrowserCardViewModel(CivitModel civitModel, IDownloadService downloadService, ISnackbarService snackbarService)
{ {
this.downloadService = downloadService; this.downloadService = downloadService;
this.snackbarService = snackbarService;
CivitModel = civitModel; CivitModel = civitModel;
} }
@ -41,18 +44,49 @@ public partial class CheckpointBrowserCardViewModel : ProgressViewModel
Text = "Downloading..."; Text = "Downloading...";
var latestModelFile = model.ModelVersions[0].Files[0]; var latestModelFile = model.ModelVersions[0].Files[0];
var fileExpectedSha256 = latestModelFile.Hashes.SHA256;
var downloadPath = Path.Combine(SharedFolders.SharedFoldersPath, var downloadPath = Path.Combine(SharedFolders.SharedFoldersPath,
SharedFolders.SharedFolderTypeToName(model.Type.ToSharedFolderType()), latestModelFile.Name); SharedFolders.SharedFolderTypeToName(model.Type.ToSharedFolderType()), latestModelFile.Name);
var progress = new Progress<ProgressReport>(progress => var downloadProgress = new Progress<ProgressReport>(progress =>
{ {
Value = progress.Percentage; Value = progress.Percentage;
Text = $"Importing... {progress.Percentage}%"; Text = $"Importing... {progress.Percentage}%";
}); });
await downloadService.DownloadToFileAsync(latestModelFile.DownloadUrl, downloadPath, progress: progress); await downloadService.DownloadToFileAsync(latestModelFile.DownloadUrl, downloadPath, progress: downloadProgress);
// When sha256 is available, validate the downloaded file
if (!string.IsNullOrEmpty(fileExpectedSha256))
{
var hashProgress = new Progress<ProgressReport>(progress =>
{
Value = progress.Percentage;
Text = $"Validating... {progress.Percentage}%";
});
var sha256 = await FileHash.GetSha256Async(downloadPath, hashProgress);
if (sha256 != fileExpectedSha256.ToLowerInvariant())
{
Text = "Import Failed!";
DelayedClearProgress(TimeSpan.FromSeconds(800));
await snackbarService.ShowSnackbarAsync(
"This may be caused by network or server issues from CivitAI, please try again in a few minutes.",
"Download failed hash validation", LogLevel.Warning);
return;
}
}
Text = "Import complete!"; Text = "Import complete!";
Value = 100; Value = 100;
DelayedClearProgress(TimeSpan.FromMilliseconds(800));
}
private void DelayedClearProgress(TimeSpan delay)
{
Task.Delay(delay).ContinueWith(_ =>
{
Text = string.Empty;
Value = 0;
});
} }
} }

10
StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs

@ -1,15 +1,13 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using NLog; using NLog;
using StabilityMatrix.Api; using StabilityMatrix.Api;
using StabilityMatrix.Models; using StabilityMatrix.Helper;
using StabilityMatrix.Models.Api; using StabilityMatrix.Models.Api;
using StabilityMatrix.Services; using StabilityMatrix.Services;
@ -20,6 +18,7 @@ public partial class CheckpointBrowserViewModel : ObservableObject
private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ICivitApi civitApi; private readonly ICivitApi civitApi;
private readonly IDownloadService downloadService; private readonly IDownloadService downloadService;
private readonly ISnackbarService snackbarService;
private const int MaxModelsPerPage = 14; private const int MaxModelsPerPage = 14;
[ObservableProperty] private string? searchQuery; [ObservableProperty] private string? searchQuery;
@ -38,10 +37,11 @@ public partial class CheckpointBrowserViewModel : ObservableObject
public IEnumerable<CivitPeriod> AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>(); public IEnumerable<CivitPeriod> AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
public IEnumerable<CivitSortMode> AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>(); public IEnumerable<CivitSortMode> AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public CheckpointBrowserViewModel(ICivitApi civitApi, IDownloadService downloadService) public CheckpointBrowserViewModel(ICivitApi civitApi, IDownloadService downloadService, ISnackbarService snackbarService)
{ {
this.civitApi = civitApi; this.civitApi = civitApi;
this.downloadService = downloadService; this.downloadService = downloadService;
this.snackbarService = snackbarService;
SelectedPeriod = CivitPeriod.Month; SelectedPeriod = CivitPeriod.Month;
SortMode = CivitSortMode.HighestRated; SortMode = CivitSortMode.HighestRated;
@ -76,7 +76,7 @@ public partial class CheckpointBrowserViewModel : ObservableObject
CanGoToPreviousPage = CurrentPageNumber > 1; CanGoToPreviousPage = CurrentPageNumber > 1;
CanGoToNextPage = CurrentPageNumber < TotalPages; CanGoToNextPage = CurrentPageNumber < TotalPages;
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(models.Items.Select( ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(models.Items.Select(
m => new CheckpointBrowserCardViewModel(m, downloadService))); m => new CheckpointBrowserCardViewModel(m, downloadService, snackbarService)));
ShowMainLoadingSpinner = false; ShowMainLoadingSpinner = false;
Logger.Debug($"Found {models.Items.Length} models"); Logger.Debug($"Found {models.Items.Length} models");

Loading…
Cancel
Save