From 0e0ccd0fb75eaaa55934efae7dc962dfba03278a Mon Sep 17 00:00:00 2001 From: Ionite Date: Sun, 20 Aug 2023 19:31:30 -0400 Subject: [PATCH] Add cancellation token for download service --- .../Services/DownloadService.cs | 78 +++++++++++++------ .../Services/IDownloadService.cs | 9 ++- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/StabilityMatrix.Core/Services/DownloadService.cs b/StabilityMatrix.Core/Services/DownloadService.cs index 6bf13e8a..28cbf6e0 100644 --- a/StabilityMatrix.Core/Services/DownloadService.cs +++ b/StabilityMatrix.Core/Services/DownloadService.cs @@ -17,43 +17,66 @@ public class DownloadService : IDownloadService this.httpClientFactory = httpClientFactory; } - public async Task DownloadToFileAsync(string downloadUrl, string downloadPath, - IProgress? progress = null, string? httpClientName = null) + public async Task DownloadToFileAsync( + string downloadUrl, + string downloadPath, + IProgress? progress = null, + string? httpClientName = null, + CancellationToken cancellationToken = default + ) { using var client = string.IsNullOrWhiteSpace(httpClientName) ? httpClientFactory.CreateClient() : httpClientFactory.CreateClient(httpClientName); - + client.Timeout = TimeSpan.FromMinutes(10); - client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("StabilityMatrix", "2.0")); - await using var file = new FileStream(downloadPath, FileMode.Create, FileAccess.Write, FileShare.None); - + client.DefaultRequestHeaders.UserAgent.Add( + new ProductInfoHeaderValue("StabilityMatrix", "2.0") + ); + await using var file = new FileStream( + downloadPath, + FileMode.Create, + FileAccess.Write, + FileShare.None + ); + long contentLength = 0; - var response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); + var response = await client + .GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead, cancellationToken) + .ConfigureAwait(false); contentLength = response.Content.Headers.ContentLength ?? 0; - + var delays = Backoff.DecorrelatedJitterBackoffV2( - TimeSpan.FromMilliseconds(50), retryCount: 3); - + TimeSpan.FromMilliseconds(50), + retryCount: 3 + ); + foreach (var delay in delays) { - if (contentLength > 0) break; + if (contentLength > 0) + break; logger.LogDebug("Retrying get-headers for content-length"); - await Task.Delay(delay); - response = await client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); + response = await client + .GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead, cancellationToken) + .ConfigureAwait(false); contentLength = response.Content.Headers.ContentLength ?? 0; } var isIndeterminate = contentLength == 0; - await using var stream = await response.Content.ReadAsStreamAsync(); + await using var stream = await response.Content + .ReadAsStreamAsync(cancellationToken) + .ConfigureAwait(false); var totalBytesRead = 0L; var buffer = new byte[BufferSize]; while (true) { - var bytesRead = await stream.ReadAsync(buffer); - if (bytesRead == 0) break; - await file.WriteAsync(buffer.AsMemory(0, bytesRead)); + var bytesRead = await stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + break; + await file.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken) + .ConfigureAwait(false); totalBytesRead += bytesRead; @@ -63,13 +86,18 @@ public class DownloadService : IDownloadService } else { - progress?.Report(new ProgressReport(current: Convert.ToUInt64(totalBytesRead), - total: Convert.ToUInt64(contentLength), message: "Downloading...")); + progress?.Report( + new ProgressReport( + current: Convert.ToUInt64(totalBytesRead), + total: Convert.ToUInt64(contentLength), + message: "Downloading..." + ) + ); } } - await file.FlushAsync(); - + await file.FlushAsync(cancellationToken).ConfigureAwait(false); + progress?.Report(new ProgressReport(1f, message: "Download complete!")); } @@ -77,11 +105,13 @@ public class DownloadService : IDownloadService { using var client = httpClientFactory.CreateClient(); client.Timeout = TimeSpan.FromSeconds(10); - client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("StabilityMatrix", "2.0")); + client.DefaultRequestHeaders.UserAgent.Add( + new ProductInfoHeaderValue("StabilityMatrix", "2.0") + ); try { - var response = await client.GetAsync(url); - return await response.Content.ReadAsStreamAsync(); + var response = await client.GetAsync(url).ConfigureAwait(false); + return await response.Content.ReadAsStreamAsync().ConfigureAwait(false); } catch (Exception e) { diff --git a/StabilityMatrix.Core/Services/IDownloadService.cs b/StabilityMatrix.Core/Services/IDownloadService.cs index 21bddb98..2f6fa6e7 100644 --- a/StabilityMatrix.Core/Services/IDownloadService.cs +++ b/StabilityMatrix.Core/Services/IDownloadService.cs @@ -4,8 +4,13 @@ namespace StabilityMatrix.Core.Services; public interface IDownloadService { - Task DownloadToFileAsync(string downloadUrl, string downloadPath, - IProgress? progress = null, string? httpClientName = null); + Task DownloadToFileAsync( + string downloadUrl, + string downloadPath, + IProgress? progress = null, + string? httpClientName = null, + CancellationToken cancellationToken = default + ); Task GetImageStreamFromUrl(string url); }