Browse Source

Download validation fixes

pull/5/head
Ionite 1 year ago
parent
commit
6d5f47ddaa
No known key found for this signature in database
  1. 4
      StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs
  2. 15
      StabilityMatrix/Services/DownloadService.cs
  3. 42
      StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs
  4. 10
      StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs

4
StabilityMatrix/DesignData/MockCheckpointBrowserViewModel.cs

@ -8,11 +8,11 @@ namespace StabilityMatrix.DesignData;
[DesignOnly(true)]
public class MockCheckpointBrowserViewModel : CheckpointBrowserViewModel
{
public MockCheckpointBrowserViewModel() : base(null!, null!)
public MockCheckpointBrowserViewModel() : base(null!, null!, null!)
{
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>
{
new (null!, null!)
new (null!, null!, null!)
{
CivitModel = new()
{

15
StabilityMatrix/Services/DownloadService.cs

@ -5,6 +5,7 @@ using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Polly.Contrib.WaitAndRetry;
using StabilityMatrix.Models;
namespace StabilityMatrix.Services;
@ -29,26 +30,28 @@ public class DownloadService : IDownloadService
await using var file = new FileStream(downloadLocation, FileMode.Create, FileAccess.Write, FileShare.None);
long contentLength = 0;
var retryCount = 0;
var response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead);
contentLength = response.Content.Headers.ContentLength ?? 0;
while (contentLength == 0 && retryCount++ < 5)
var delays = Backoff.DecorrelatedJitterBackoffV2(
TimeSpan.FromMilliseconds(50), retryCount: 3);
foreach (var delay in delays)
{
if (contentLength > 0) break;
logger.LogDebug("Retrying get-headers for content-length");
Thread.Sleep(50);
await Task.Delay(delay);
response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead);
contentLength = response.Content.Headers.ContentLength ?? 0;
}
var isIndeterminate = contentLength == 0;
await using var stream = await response.Content.ReadAsStreamAsync();
var totalBytesRead = 0L;
var buffer = new byte[bufferSize];
while (true)
{
var buffer = new byte[bufferSize];
var bytesRead = await stream.ReadAsync(buffer);
if (bytesRead == 0) break;
await file.WriteAsync(buffer.AsMemory(0, bytesRead));

42
StabilityMatrix/ViewModels/CheckpointBrowserCardViewModel.cs

@ -1,10 +1,11 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Net.Mime;
using System.Threading.Tasks;
using System.Windows;
using CommunityToolkit.Mvvm.Input;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Helper;
using StabilityMatrix.Models;
using StabilityMatrix.Models.Api;
using StabilityMatrix.Services;
@ -14,14 +15,16 @@ namespace StabilityMatrix.ViewModels;
public partial class CheckpointBrowserCardViewModel : ProgressViewModel
{
private readonly IDownloadService downloadService;
private readonly ISnackbarService snackbarService;
public CivitModel CivitModel { get; init; }
public override Visibility ProgressVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed;
public override Visibility TextVisibility => Value > 0 ? Visibility.Visible : Visibility.Collapsed;
public CheckpointBrowserCardViewModel(CivitModel civitModel, IDownloadService downloadService)
public CheckpointBrowserCardViewModel(CivitModel civitModel, IDownloadService downloadService, ISnackbarService snackbarService)
{
this.downloadService = downloadService;
this.snackbarService = snackbarService;
CivitModel = civitModel;
}
@ -41,18 +44,49 @@ public partial class CheckpointBrowserCardViewModel : ProgressViewModel
Text = "Downloading...";
var latestModelFile = model.ModelVersions[0].Files[0];
var fileExpectedSha256 = latestModelFile.Hashes.SHA256;
var downloadPath = Path.Combine(SharedFolders.SharedFoldersPath,
SharedFolders.SharedFolderTypeToName(model.Type.ToSharedFolderType()), latestModelFile.Name);
var progress = new Progress<ProgressReport>(progress =>
var downloadProgress = new Progress<ProgressReport>(progress =>
{
Value = progress.Percentage;
Text = $"Importing... {progress.Percentage}%";
});
await downloadService.DownloadToFileAsync(latestModelFile.DownloadUrl, downloadPath, progress: progress);
await downloadService.DownloadToFileAsync(latestModelFile.DownloadUrl, downloadPath, progress: downloadProgress);
// When sha256 is available, validate the downloaded file
if (!string.IsNullOrEmpty(fileExpectedSha256))
{
var hashProgress = new Progress<ProgressReport>(progress =>
{
Value = progress.Percentage;
Text = $"Validating... {progress.Percentage}%";
});
var sha256 = await FileHash.GetSha256Async(downloadPath, hashProgress);
if (sha256 != fileExpectedSha256.ToLowerInvariant())
{
Text = "Import Failed!";
DelayedClearProgress(TimeSpan.FromSeconds(800));
await snackbarService.ShowSnackbarAsync(
"This may be caused by network or server issues from CivitAI, please try again in a few minutes.",
"Download failed hash validation", LogLevel.Warning);
return;
}
}
Text = "Import complete!";
Value = 100;
DelayedClearProgress(TimeSpan.FromMilliseconds(800));
}
private void DelayedClearProgress(TimeSpan delay)
{
Task.Delay(delay).ContinueWith(_ =>
{
Text = string.Empty;
Value = 0;
});
}
}

10
StabilityMatrix/ViewModels/CheckpointBrowserViewModel.cs

@ -1,15 +1,13 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using NLog;
using StabilityMatrix.Api;
using StabilityMatrix.Models;
using StabilityMatrix.Helper;
using StabilityMatrix.Models.Api;
using StabilityMatrix.Services;
@ -20,6 +18,7 @@ public partial class CheckpointBrowserViewModel : ObservableObject
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly ICivitApi civitApi;
private readonly IDownloadService downloadService;
private readonly ISnackbarService snackbarService;
private const int MaxModelsPerPage = 14;
[ObservableProperty] private string? searchQuery;
@ -38,10 +37,11 @@ public partial class CheckpointBrowserViewModel : ObservableObject
public IEnumerable<CivitPeriod> AllCivitPeriods => Enum.GetValues(typeof(CivitPeriod)).Cast<CivitPeriod>();
public IEnumerable<CivitSortMode> AllSortModes => Enum.GetValues(typeof(CivitSortMode)).Cast<CivitSortMode>();
public CheckpointBrowserViewModel(ICivitApi civitApi, IDownloadService downloadService)
public CheckpointBrowserViewModel(ICivitApi civitApi, IDownloadService downloadService, ISnackbarService snackbarService)
{
this.civitApi = civitApi;
this.downloadService = downloadService;
this.snackbarService = snackbarService;
SelectedPeriod = CivitPeriod.Month;
SortMode = CivitSortMode.HighestRated;
@ -76,7 +76,7 @@ public partial class CheckpointBrowserViewModel : ObservableObject
CanGoToPreviousPage = CurrentPageNumber > 1;
CanGoToNextPage = CurrentPageNumber < TotalPages;
ModelCards = new ObservableCollection<CheckpointBrowserCardViewModel>(models.Items.Select(
m => new CheckpointBrowserCardViewModel(m, downloadService)));
m => new CheckpointBrowserCardViewModel(m, downloadService, snackbarService)));
ShowMainLoadingSpinner = false;
Logger.Debug($"Found {models.Items.Length} models");

Loading…
Cancel
Save