using System.Net;
using System.Net.Http.Headers;
using Microsoft.Extensions.Logging;
using Polly.Contrib.WaitAndRetry;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Helper;
using StabilityMatrix.Core.Models.Progress;

namespace StabilityMatrix.Core.Services;

/// <summary>
/// HTTP download service: streams remote files to disk with progress reporting,
/// resumes partial downloads via Range requests, and attaches a Civitai bearer token
/// to civitai.com requests when one is stored in the secrets manager.
/// </summary>
[Singleton(typeof(IDownloadService))]
public class DownloadService : IDownloadService
{
    private readonly ILogger<DownloadService> logger;
    private readonly IHttpClientFactory httpClientFactory;
    private readonly ISecretsManager secretsManager;

    /// <summary>Size in bytes of the buffer used to copy response bodies to disk.</summary>
    private const int BufferSize = ushort.MaxValue;

    /// <summary>Client timeout used for file downloads and size probes.</summary>
    private static readonly TimeSpan DownloadTimeout = TimeSpan.FromMinutes(10);

    /// <summary>Client timeout used when fetching images.</summary>
    private static readonly TimeSpan ImageTimeout = TimeSpan.FromSeconds(10);

    public DownloadService(
        ILogger<DownloadService> logger,
        IHttpClientFactory httpClientFactory,
        ISecretsManager secretsManager
    )
    {
        this.logger = logger;
        this.httpClientFactory = httpClientFactory;
        this.secretsManager = secretsManager;
    }

    /// <inheritdoc />
    /// <exception cref="HttpRequestException">The server returned a non-success status code.</exception>
    public async Task DownloadToFileAsync(
        string downloadUrl,
        string downloadPath,
        IProgress<ProgressReport>? progress = null,
        string? httpClientName = null,
        CancellationToken cancellationToken = default
    )
    {
        var uri = new Uri(downloadUrl);

        using var client = await CreateClientAsync(httpClientName, uri, DownloadTimeout)
            .ConfigureAwait(false);

        await using var file = new FileStream(
            downloadPath,
            FileMode.Create,
            FileAccess.Write,
            FileShare.None
        );

        var response = await client
            .GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
            .ConfigureAwait(false);
        try
        {
            var contentLength = response.Content.Headers.ContentLength ?? 0;

            // Some servers omit Content-Length on the first response; retry a few times
            // with jittered backoff before falling back to indeterminate progress.
            foreach (
                var delay in Backoff.DecorrelatedJitterBackoffV2(
                    TimeSpan.FromMilliseconds(50),
                    retryCount: 3
                )
            )
            {
                if (contentLength > 0)
                    break;

                logger.LogDebug("Retrying get-headers for content-length");
                await Task.Delay(delay, cancellationToken).ConfigureAwait(false);

                // Release the previous attempt's connection before issuing another request
                response.Dispose();
                response = await client
                    .GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
                    .ConfigureAwait(false);
                contentLength = response.Content.Headers.ContentLength ?? 0;
            }

            // Never write an error/redirect body into the target file
            response.EnsureSuccessStatusCode();

            await using var stream = await response.Content
                .ReadAsStreamAsync(cancellationToken)
                .ConfigureAwait(false);

            await CopyWithProgressAsync(
                    stream,
                    file,
                    startOffset: 0,
                    totalLength: contentLength,
                    progress,
                    cancellationToken
                )
                .ConfigureAwait(false);
        }
        finally
        {
            response.Dispose();
        }

        progress?.Report(new ProgressReport(1f, message: "Download complete!"));
    }

    /// <inheritdoc />
    /// <exception cref="UnauthorizedAccessException">
    /// The server redirected to a URL containing "reason=download-auth".
    /// </exception>
    /// <exception cref="HttpRequestException">The server returned a non-success status code.</exception>
    public async Task ResumeDownloadToFileAsync(
        string downloadUrl,
        string downloadPath,
        long existingFileSize,
        IProgress<ProgressReport>? progress = null,
        string? httpClientName = null,
        CancellationToken cancellationToken = default
    )
    {
        var uri = new Uri(downloadUrl);

        using var client = await CreateClientAsync(httpClientName, uri, DownloadTimeout)
            .ConfigureAwait(false);

        // Used only to inspect the first hop, so an auth redirect can be detected
        using var noRedirectClient = httpClientFactory.CreateClient("DontFollowRedirects");
        await AddConditionalHeaders(noRedirectClient, uri).ConfigureAwait(false);

        // Create file if it doesn't exist
        if (!File.Exists(downloadPath))
        {
            logger.LogInformation(
                "Resume file doesn't exist, creating file {DownloadPath}",
                downloadPath
            );
            File.Create(downloadPath).Close();
        }

        // Opened without Append so it can be truncated if the server ignores the Range header
        await using var file = new FileStream(
            downloadPath,
            FileMode.OpenOrCreate,
            FileAccess.Write,
            FileShare.None
        );

        HttpResponseMessage? response = null;
        try
        {
            // Remaining content length (bytes in this response body)
            long remainingContentLength = 0;
            // Total length of the original content, taken from Content-Range
            long originalContentLength = 0;

            foreach (
                var delay in Backoff.DecorrelatedJitterBackoffV2(
                    TimeSpan.FromMilliseconds(50),
                    retryCount: 4
                )
            )
            {
                // An HttpRequestMessage can only be sent once, so every attempt builds fresh ones
                using (var probeRequest = CreateRangeRequest(uri, existingFileSize))
                using (
                    var probeResponse = await noRedirectClient
                        .SendAsync(
                            probeRequest,
                            HttpCompletionOption.ResponseHeadersRead,
                            cancellationToken
                        )
                        .ConfigureAwait(false)
                )
                {
                    // A redirect whose target carries "reason=download-auth" is treated as an auth failure
                    if (
                        (int)probeResponse.StatusCode is > 299 and < 400
                        && probeResponse.Headers.Location?.ToString() is { } redirectUrl
                        && redirectUrl.Contains("reason=download-auth")
                    )
                    {
                        throw new UnauthorizedAccessException();
                    }
                }

                // Release the previous attempt's connection before issuing another request
                response?.Dispose();

                using var request = CreateRangeRequest(uri, existingFileSize);
                response = await client
                    .SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
                    .ConfigureAwait(false);

                remainingContentLength = response.Content.Headers.ContentLength ?? 0;
                originalContentLength =
                    response.Content.Headers.ContentRange?.Length.GetValueOrDefault() ?? 0;

                if (remainingContentLength > 0)
                    break;

                logger.LogDebug("Retrying get-headers for content-length");
                await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
            }

            if (response is null)
            {
                throw new ApplicationException("Response is null");
            }

            // Never write an error/redirect body into the partial file
            response.EnsureSuccessStatusCode();

            // 206 means the server honoured the Range header and sent only the missing tail.
            // Any other success status (typically 200) carries the whole body, which must
            // replace the partial file rather than be appended to it.
            long startOffset;
            if (response.StatusCode == HttpStatusCode.PartialContent)
            {
                file.Seek(0, SeekOrigin.End);
                startOffset = existingFileSize;
            }
            else
            {
                if (existingFileSize > 0)
                {
                    logger.LogInformation(
                        "Server did not honour range request for {Url} (status {StatusCode}), restarting download",
                        downloadUrl,
                        response.StatusCode
                    );
                }
                file.SetLength(0);
                startOffset = 0;
            }

            // Progress total: the full size from Content-Range when present, otherwise inferred
            // from bytes already on disk plus this body. 0 means unknown (indeterminate progress).
            var totalLength =
                remainingContentLength <= 0 ? 0
                : originalContentLength > 0 ? originalContentLength
                : startOffset + remainingContentLength;

            await using var stream = await response.Content
                .ReadAsStreamAsync(cancellationToken)
                .ConfigureAwait(false);

            await CopyWithProgressAsync(
                    stream,
                    file,
                    startOffset,
                    totalLength,
                    progress,
                    cancellationToken
                )
                .ConfigureAwait(false);
        }
        finally
        {
            response?.Dispose();
        }

        progress?.Report(new ProgressReport(1f, message: "Download complete!"));
    }

    /// <inheritdoc />
    /// <returns>
    /// The Content-Length reported by the server, or -1 if the last attempt had no
    /// Content-Length header.
    /// </returns>
    public async Task<long> GetFileSizeAsync(
        string downloadUrl,
        string? httpClientName = null,
        CancellationToken cancellationToken = default
    )
    {
        var uri = new Uri(downloadUrl);

        using var client = await CreateClientAsync(httpClientName, uri, DownloadTimeout)
            .ConfigureAwait(false);

        var contentLength = 0L;
        foreach (
            var delay in Backoff.DecorrelatedJitterBackoffV2(
                TimeSpan.FromMilliseconds(50),
                retryCount: 3
            )
        )
        {
            // Only headers are needed; disposing the response releases the connection
            // instead of leaving the unread body pending.
            using (
                var response = await client
                    .GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
                    .ConfigureAwait(false)
            )
            {
                contentLength = response.Content.Headers.ContentLength ?? -1;
            }

            if (contentLength > 0)
                break;

            logger.LogDebug("Retrying get-headers for content-length");
            await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
        }

        return contentLength;
    }

    /// <summary>
    /// Fetches <paramref name="url"/> and returns its body as a stream.
    /// </summary>
    /// <returns>The response body, or null if the request failed or returned a non-success status.</returns>
    public async Task<Stream?> GetImageStreamFromUrl(string url)
    {
        using var client = await CreateClientAsync(null, new Uri(url), ImageTimeout)
            .ConfigureAwait(false);

        try
        {
            // The default completion option buffers the whole body, so the returned stream
            // remains readable after the client is disposed; the response is intentionally
            // not disposed on success because that would dispose the returned stream.
            var response = await client.GetAsync(url).ConfigureAwait(false);
            if (!response.IsSuccessStatusCode)
            {
                logger.LogError(
                    "Failed to get image stream from url {Url}: status {StatusCode}",
                    url,
                    response.StatusCode
                );
                response.Dispose();
                return null;
            }

            return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
        }
        catch (Exception e)
        {
            logger.LogError(e, "Failed to get image stream from url {Url}", url);
            return null;
        }
    }

    /// <summary>
    /// Creates an <see cref="HttpClient"/> (the named client when <paramref name="httpClientName"/>
    /// is non-blank, otherwise the default one), sets <paramref name="timeout"/> and the
    /// StabilityMatrix user agent, and adds host-specific headers for <paramref name="url"/>.
    /// </summary>
    /// <remarks>The caller owns and must dispose the returned client.</remarks>
    private async Task<HttpClient> CreateClientAsync(string? httpClientName, Uri url, TimeSpan timeout)
    {
        var client = string.IsNullOrWhiteSpace(httpClientName)
            ? httpClientFactory.CreateClient()
            : httpClientFactory.CreateClient(httpClientName);

        try
        {
            client.Timeout = timeout;
            client.DefaultRequestHeaders.UserAgent.Add(
                new ProductInfoHeaderValue("StabilityMatrix", "2.0")
            );
            await AddConditionalHeaders(client, url).ConfigureAwait(false);
            return client;
        }
        catch
        {
            // Don't leak the client if header setup (e.g. loading secrets) fails
            client.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Builds a GET request for <paramref name="uri"/> asking for bytes from
    /// <paramref name="rangeStart"/> to the end of the resource.
    /// </summary>
    private static HttpRequestMessage CreateRangeRequest(Uri uri, long rangeStart) =>
        new(HttpMethod.Get, uri) { Headers = { Range = new RangeHeaderValue(rangeStart, null) } };

    /// <summary>
    /// Copies <paramref name="source"/> into <paramref name="destination"/>, reporting progress
    /// after every chunk, then flushes <paramref name="destination"/>.
    /// </summary>
    /// <param name="startOffset">Bytes already present before this copy; added to the reported current value.</param>
    /// <param name="totalLength">Expected final size in bytes, or 0 when unknown (progress is reported as indeterminate).</param>
    private static async Task CopyWithProgressAsync(
        Stream source,
        Stream destination,
        long startOffset,
        long totalLength,
        IProgress<ProgressReport>? progress,
        CancellationToken cancellationToken
    )
    {
        var isIndeterminate = totalLength == 0;
        var buffer = new byte[BufferSize];
        var totalBytesRead = 0L;

        while (true)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var bytesRead = await source.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
            if (bytesRead == 0)
                break;

            await destination
                .WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken)
                .ConfigureAwait(false);

            totalBytesRead += bytesRead;

            if (isIndeterminate)
            {
                progress?.Report(new ProgressReport(-1, isIndeterminate: true));
            }
            else
            {
                progress?.Report(
                    new ProgressReport(
                        current: Convert.ToUInt64(startOffset + totalBytesRead),
                        total: Convert.ToUInt64(totalLength),
                        message: "Downloading..."
                    )
                );
            }
        }

        await destination.FlushAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Adds conditional headers to the HttpClient for the given URL: for civitai.com,
    /// a Bearer token is added when the loaded secrets contain a Civitai API entry.
    /// </summary>
    private async Task AddConditionalHeaders(HttpClient client, Uri url)
    {
        // Check if civit download
        if (url.Host.Equals("civitai.com", StringComparison.OrdinalIgnoreCase))
        {
            // Add auth if we have it
            if (
                await secretsManager.LoadAsync().ConfigureAwait(false) is
                { CivitApi: { } civitApi }
            )
            {
                logger.LogTrace(
                    "Adding Civit auth header {Signature} for download {Url}",
                    ObjectHash.GetStringSignature(civitApi.ApiToken),
                    url
                );
                client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue(
                    "Bearer",
                    civitApi.ApiToken
                );
            }
        }
    }
}