Browse Source

Add jwt token refreshing, Improved UserSecrets deserialization error handling

pull/324/head
Ionite 1 year ago
parent
commit
d0979941b5
No known key found for this signature in database
  1. 16
      StabilityMatrix.Avalonia/App.axaml.cs
  2. 63
      StabilityMatrix.Avalonia/Services/AccountsService.cs
  3. 15
      StabilityMatrix.Avalonia/ViewModels/Settings/AccountSettingsViewModel.cs
  4. 6
      StabilityMatrix.Core/Api/ILykosAuthApi.cs
  5. 7
      StabilityMatrix.Core/Api/ITokenProvider.cs
  6. 52
      StabilityMatrix.Core/Api/LykosAuthTokenProvider.cs
  7. 70
      StabilityMatrix.Core/Api/TokenAuthHeaderHandler.cs
  8. 3
      StabilityMatrix.Core/Models/Api/Lykos/PostLoginRefreshRequest.cs
  9. 70
      StabilityMatrix.Core/Models/GlobalEncryptedSerializer.cs
  10. 8
      StabilityMatrix.Core/Models/Secrets.cs
  11. 21
      StabilityMatrix.Core/Services/ISecretsManager.cs
  12. 64
      StabilityMatrix.Core/Services/SecretsManager.cs
  13. 1
      StabilityMatrix.Core/StabilityMatrix.Core.csproj

16
StabilityMatrix.Avalonia/App.axaml.cs

@ -529,20 +529,20 @@ public sealed class App : Application
})
.AddPolicyHandler(retryPolicy);
var lykosAuthRefitSettings = new RefitSettings
{
ContentSerializer = new SystemTextJsonContentSerializer(jsonSerializerOptions),
AuthorizationHeaderValueGetter = (_, ct) =>
Task.FromResult(GlobalUserSecrets.LoadFromFile().LykosAccessToken ?? "")
};
services
.AddRefitClient<ILykosAuthApi>(lykosAuthRefitSettings)
.AddRefitClient<ILykosAuthApi>(defaultRefitSettings)
.ConfigureHttpClient(c =>
{
c.BaseAddress = new Uri("https://stableauthentication.azurewebsites.net");
c.Timeout = TimeSpan.FromSeconds(15);
})
.AddPolicyHandler(retryPolicy);
.AddPolicyHandler(retryPolicy)
.AddHttpMessageHandler(
serviceProvider =>
new TokenAuthHeaderHandler(
serviceProvider.GetRequiredService<LykosAuthTokenProvider>()
)
);
// Add Refit client managers
services

63
StabilityMatrix.Avalonia/Services/AccountsService.cs

@ -1,4 +1,5 @@
using System;
using System.Net;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Octokit;
@ -6,6 +7,7 @@ using StabilityMatrix.Core.Api;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Api.Lykos;
using StabilityMatrix.Core.Services;
using ApiException = Refit.ApiException;
namespace StabilityMatrix.Avalonia.Services;
@ -14,6 +16,7 @@ namespace StabilityMatrix.Avalonia.Services;
public class AccountsService : IAccountsService
{
private readonly ILogger<AccountsService> logger;
private readonly ISecretsManager secretsManager;
private readonly ILykosAuthApi lykosAuthApi;
private readonly ICivitTRPCApi civitTRPCApi;
@ -24,11 +27,13 @@ public class AccountsService : IAccountsService
public AccountsService(
ILogger<AccountsService> logger,
ISecretsManager secretsManager,
ILykosAuthApi lykosAuthApi,
ICivitTRPCApi civitTRPCApi
)
{
this.logger = logger;
this.secretsManager = secretsManager;
this.lykosAuthApi = lykosAuthApi;
this.civitTRPCApi = civitTRPCApi;
@ -38,55 +43,48 @@ public class AccountsService : IAccountsService
public async Task LykosLoginAsync(string email, string password)
{
var secrets = GlobalUserSecrets.LoadFromFile();
var secrets = await secretsManager.SafeLoadAsync();
var tokens = await lykosAuthApi.PostLogin(new PostLoginRequest(email, password));
secrets.LykosAccessToken = tokens.AccessToken;
secrets.LykosRefreshToken = tokens.RefreshToken;
secrets.SaveToFile();
await secretsManager.SaveAsync(secrets with { LykosAccount = tokens });
await RefreshAsync();
}
public async Task LykosSignupAsync(string email, string password, string username)
{
var secrets = GlobalUserSecrets.LoadFromFile();
var secrets = await secretsManager.SafeLoadAsync();
var tokens = await lykosAuthApi.PostAccount(
new PostAccountRequest(email, password, password, username)
);
secrets.LykosAccessToken = tokens.AccessToken;
secrets.LykosRefreshToken = tokens.RefreshToken;
secrets.SaveToFile();
secrets = secrets with { LykosAccount = tokens };
await RefreshAsync();
await secretsManager.SaveAsync(secrets);
await RefreshLykosAsync(secrets);
}
public Task LykosLogoutAsync()
public async Task LykosLogoutAsync()
{
var secrets = GlobalUserSecrets.LoadFromFile();
secrets.LykosAccessToken = null;
secrets.LykosRefreshToken = null;
secrets.SaveToFile();
var secrets = await secretsManager.SafeLoadAsync();
await secretsManager.SaveAsync(secrets with { LykosAccount = null });
OnLykosAccountStatusUpdate(LykosAccountStatusUpdateEventArgs.Disconnected);
return Task.CompletedTask;
}
public async Task RefreshAsync()
{
var secrets = GlobalUserSecrets.LoadFromFile();
var secrets = await secretsManager.SafeLoadAsync();
await RefreshLykosAsync(secrets);
}
private async Task RefreshLykosAsync(GlobalUserSecrets secrets)
private async Task RefreshLykosAsync(Secrets secrets)
{
if (secrets.LykosAccessToken is { } accessToken && !string.IsNullOrEmpty(accessToken))
if (secrets.LykosAccount is not null)
{
try
{
@ -100,17 +98,36 @@ public class AccountsService : IAccountsService
}
catch (OperationCanceledException)
{
logger.LogWarning("Timed out");
logger.LogWarning("Timed out while fetching Lykos Auth user info");
}
catch (ApiException e)
{
logger.LogWarning(e, "Failed to get user info from Lykos");
if (e.StatusCode is HttpStatusCode.Unauthorized) { }
else
{
logger.LogWarning(e, "Failed to get user info from Lykos");
}
}
}
OnLykosAccountStatusUpdate(LykosAccountStatusUpdateEventArgs.Disconnected);
}
private void OnLykosAccountStatusUpdate(LykosAccountStatusUpdateEventArgs e) =>
private void OnLykosAccountStatusUpdate(LykosAccountStatusUpdateEventArgs e)
{
if (!e.IsConnected && LykosStatus?.IsConnected == true)
{
logger.LogInformation("Lykos account disconnected");
}
else if (e.IsConnected && LykosStatus?.IsConnected == false)
{
logger.LogInformation(
"Lykos account connected: {Id} ({Username})",
e.User?.Id,
e.User?.Account.Name
);
}
LykosAccountStatusUpdate?.Invoke(this, e);
}
}

15
StabilityMatrix.Avalonia/ViewModels/Settings/AccountSettingsViewModel.cs

@ -2,7 +2,9 @@
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using FluentAvalonia.UI.Controls;
@ -74,20 +76,25 @@ public partial class AccountSettingsViewModel : PageViewModelBase
accountsService.LykosAccountStatusUpdate += (_, args) =>
{
IsLykosConnected = args.IsConnected;
LykosUser = args.User;
Dispatcher.UIThread.Post(() =>
{
IsLykosConnected = args.IsConnected;
LykosUser = args.User;
});
};
}
/// <inheritdoc />
public override async Task OnLoadedAsync()
public override void OnLoaded()
{
base.OnLoaded();
if (Design.IsDesignMode)
{
return;
}
await accountsService.RefreshAsync();
accountsService.RefreshAsync().SafeFireAndForget();
}
[RelayCommand]

6
StabilityMatrix.Core/Api/ILykosAuthApi.cs

@ -25,4 +25,10 @@ public interface ILykosAuthApi
[Body] PostLoginRequest request,
CancellationToken cancellationToken = default
);
[Post("/api/Login/Refresh")]
Task<LykosAccountTokens> PostLoginRefresh(
[Body] PostLoginRefreshRequest request,
CancellationToken cancellationToken = default
);
}

7
StabilityMatrix.Core/Api/ITokenProvider.cs

@ -0,0 +1,7 @@
namespace StabilityMatrix.Core.Api;
public interface ITokenProvider
{
Task<string> GetAccessTokenAsync();
Task<(string AccessToken, string RefreshToken)> RefreshTokensAsync();
}

52
StabilityMatrix.Core/Api/LykosAuthTokenProvider.cs

@ -0,0 +1,52 @@
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Lykos;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Core.Api;
[Singleton]
public class LykosAuthTokenProvider : ITokenProvider
{
private readonly ISecretsManager secretsManager;
private readonly Lazy<ILykosAuthApi> lazyLykosAuthApi;
public LykosAuthTokenProvider(
Lazy<ILykosAuthApi> lazyLykosAuthApi,
ISecretsManager secretsManager
)
{
// Lazy as instantiating requires the current class to be instantiated.
this.lazyLykosAuthApi = lazyLykosAuthApi;
this.secretsManager = secretsManager;
}
/// <inheritdoc />
public async Task<string> GetAccessTokenAsync()
{
var secrets = await secretsManager.SafeLoadAsync().ConfigureAwait(false);
return secrets.LykosAccount?.AccessToken ?? "";
}
/// <inheritdoc />
public async Task<(string AccessToken, string RefreshToken)> RefreshTokensAsync()
{
var secrets = await secretsManager.SafeLoadAsync().ConfigureAwait(false);
if (string.IsNullOrWhiteSpace(secrets.LykosAccount?.RefreshToken))
{
throw new InvalidOperationException("No refresh token found");
}
var lykosAuthApi = lazyLykosAuthApi.Value;
var newTokens = await lykosAuthApi
.PostLoginRefresh(new PostLoginRefreshRequest(secrets.LykosAccount.RefreshToken))
.ConfigureAwait(false);
secrets = secrets with { LykosAccount = newTokens };
await secretsManager.SaveAsync(secrets).ConfigureAwait(false);
return (newTokens.AccessToken, newTokens.RefreshToken);
}
}

70
StabilityMatrix.Core/Api/TokenAuthHeaderHandler.cs

@ -0,0 +1,70 @@
using System.Net;
using System.Net.Http.Headers;
using NLog;
using Polly;
using Polly.Retry;
using StabilityMatrix.Core.Helper;
namespace StabilityMatrix.Core.Api;
public class TokenAuthHeaderHandler : DelegatingHandler
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
private readonly AsyncRetryPolicy<HttpResponseMessage> policy;
private readonly ITokenProvider tokenProvider;
public TokenAuthHeaderHandler(ITokenProvider tokenProvider)
{
this.tokenProvider = tokenProvider;
policy = Policy
.HandleResult<HttpResponseMessage>(
r =>
r.StatusCode is HttpStatusCode.Unauthorized or HttpStatusCode.Forbidden
&& r.RequestMessage?.Headers.Authorization
is { Parameter: "Bearer", Scheme: not null }
)
.RetryAsync(
async (result, _) =>
{
var oldToken = ObjectHash.GetStringSignature(
await tokenProvider.GetAccessTokenAsync().ConfigureAwait(false)
);
Logger.Info(
"Refreshing access token for status ({StatusCode}) {Message}",
result.Result.StatusCode,
result.Exception.Message
);
var (newToken, _) = await tokenProvider
.RefreshTokensAsync()
.ConfigureAwait(false);
Logger.Info(
"Access token refreshed: {OldToken} -> {NewToken}",
ObjectHash.GetStringSignature(oldToken),
ObjectHash.GetStringSignature(newToken)
);
}
);
// InnerHandler must be left as null when using DI, but must be assigned a value when
// using RestService.For<IMyApi>
// InnerHandler = new HttpClientHandler();
}
protected override Task<HttpResponseMessage> SendAsync(
HttpRequestMessage request,
CancellationToken cancellationToken
)
{
return policy.ExecuteAsync(async () =>
{
var accessToken = await tokenProvider.GetAccessTokenAsync().ConfigureAwait(false);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
});
}
}

3
StabilityMatrix.Core/Models/Api/Lykos/PostLoginRefreshRequest.cs

@ -0,0 +1,3 @@
namespace StabilityMatrix.Core.Models.Api.Lykos;
public record PostLoginRefreshRequest(string RefreshToken);

70
StabilityMatrix.Core/Models/GlobalUserSecrets.cs → StabilityMatrix.Core/Models/GlobalEncryptedSerializer.cs

@ -2,30 +2,44 @@
using System.Security;
using System.Security.Cryptography;
using System.Text.Json;
using System.Text.Json.Serialization;
using DeviceId;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Core.Models;
internal record struct KeyInfo(byte[] Key, byte[] Salt, int Iterations);
/// <summary>
/// Global instance of user secrets.
/// Stored in %APPDATA%\StabilityMatrix\user-secrets.data
/// Encrypted MessagePack Serializer that uses a global key derived from the computer's SID.
/// Header contains additional random entropy as a salt that is used in decryption.
/// </summary>
public class GlobalUserSecrets
public static class GlobalEncryptedSerializer
{
private const int KeySize = 32;
private const int Iterations = 300;
private const int SaltSize = 16;
[JsonIgnore]
public static FilePath File { get; } = GlobalConfig.HomeDir + "user-secrets.data";
public static T Deserialize<T>(ReadOnlySpan<byte> data)
{
// Get salt from start of file
var salt = data[..SaltSize].ToArray();
// Get encrypted json from rest of file
var encryptedJson = data[SaltSize..];
var json = DecryptBytes(encryptedJson, salt);
return JsonSerializer.Deserialize<T>(json)
?? throw new Exception("Deserialize returned null");
}
public string? LykosAccessToken { get; set; }
public static byte[] Serialize<T>(T obj)
{
var json = JsonSerializer.SerializeToUtf8Bytes(obj);
var (encrypted, salt) = EncryptBytes(json);
// Prepend salt to encrypted json
var fileBytes = salt.Concat(encrypted).ToArray();
public string? LykosRefreshToken { get; set; }
return fileBytes;
}
private static string? GetComputerSid()
{
@ -96,7 +110,7 @@ public class GlobalUserSecrets
passwordByteArray[i] = Marshal.ReadByte(ptr, i);
}
using var rfc2898 = new Rfc2898DeriveBytes(passwordByteArray, salt, iterations);
using var rfc2898 = new Rfc2898DeriveBytes(passwordByteArray, salt, iterations, HashAlgorithmName.SHA512);
return rfc2898.GetBytes(keyLength);
}
finally
@ -125,7 +139,7 @@ public class GlobalUserSecrets
return (transform.TransformFinalBlock(data, 0, data.Length), keyInfo.Salt);
}
private static byte[] DecryptBytes(IReadOnlyCollection<byte> encryptedData, byte[] salt)
private static byte[] DecryptBytes(ReadOnlySpan<byte> encryptedData, byte[] salt)
{
var key = DeriveKey(GetComputerKeyPhrase(), salt, Iterations, KeySize);
@ -136,38 +150,6 @@ public class GlobalUserSecrets
aes.Mode = CipherMode.CBC;
var transform = aes.CreateDecryptor();
return transform.TransformFinalBlock(encryptedData.ToArray(), 0, encryptedData.Count);
}
public void SaveToFile()
{
var json = JsonSerializer.SerializeToUtf8Bytes(this);
var (encrypted, salt) = EncryptBytes(json);
// Prepend salt to encrypted json
var fileBytes = salt.Concat(encrypted).ToArray();
File.WriteAllBytes(fileBytes);
}
public static GlobalUserSecrets LoadFromFile()
{
File.Info.Refresh();
if (!File.Exists)
{
return new GlobalUserSecrets();
}
var fileBytes = File.ReadAllBytes();
// Get salt from start of file
var salt = fileBytes.AsSpan(0, SaltSize).ToArray();
// Get encrypted json from rest of file
var encryptedJson = fileBytes.AsSpan(SaltSize).ToArray();
var json = DecryptBytes(encryptedJson, salt);
return JsonSerializer.Deserialize<GlobalUserSecrets>(json)
?? throw new Exception("Deserialized user secrets is null");
return transform.TransformFinalBlock(encryptedData.ToArray(), 0, encryptedData.Length);
}
}

8
StabilityMatrix.Core/Models/Secrets.cs

@ -0,0 +1,8 @@
using StabilityMatrix.Core.Models.Api.Lykos;
namespace StabilityMatrix.Core.Models;
public readonly record struct Secrets
{
public LykosAccountTokens? LykosAccount { get; init; }
}

21
StabilityMatrix.Core/Services/ISecretsManager.cs

@ -0,0 +1,21 @@
using StabilityMatrix.Core.Models;
namespace StabilityMatrix.Core.Services;
/// <summary>
/// Interface for managing secure settings and tokens.
/// </summary>
public interface ISecretsManager
{
/// <summary>
/// Load and return the secrets.
/// </summary>
Task<Secrets> LoadAsync();
/// <summary>
/// Load and return the secrets, or save and return a new instance on error.
/// </summary>
Task<Secrets> SafeLoadAsync();
Task SaveAsync(Secrets secrets);
}

64
StabilityMatrix.Core/Services/SecretsManager.cs

@ -0,0 +1,64 @@
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Core.Services;
/// <summary>
/// Default implementation of <see cref="ISecretsManager"/>.
/// Data is encrypted at rest in %APPDATA%\StabilityMatrix\user-secrets.data
/// </summary>
[Singleton(typeof(ISecretsManager))]
public class SecretsManager : ISecretsManager
{
private readonly ILogger<SecretsManager> logger;
private static FilePath GlobalFile => GlobalConfig.HomeDir.JoinFile("user-secrets.data");
public SecretsManager(ILogger<SecretsManager> logger)
{
this.logger = logger;
}
/// <inheritdoc />
public async Task<Secrets> LoadAsync()
{
if (!GlobalFile.Exists)
{
return new Secrets();
}
var fileBytes = await GlobalFile.ReadAllBytesAsync().ConfigureAwait(false);
return GlobalEncryptedSerializer.Deserialize<Secrets>(fileBytes);
}
/// <inheritdoc />
public async Task<Secrets> SafeLoadAsync()
{
try
{
return await LoadAsync().ConfigureAwait(false);
}
catch (Exception e)
{
logger.LogWarning(
e,
"Failed to load secrets ({ExcType}), saving new instance",
e.GetType().Name
);
var secrets = new Secrets();
await SaveAsync(secrets).ConfigureAwait(false);
return secrets;
}
}
/// <inheritdoc />
public Task SaveAsync(Secrets secrets)
{
var fileBytes = GlobalEncryptedSerializer.Serialize(secrets);
return GlobalFile.WriteAllBytesAsync(fileBytes);
}
}

1
StabilityMatrix.Core/StabilityMatrix.Core.csproj

@ -35,6 +35,7 @@
<PackageReference Include="Octokit" Version="8.1.1" />
<PackageReference Include="OneOf" Version="3.0.263" />
<PackageReference Include="OneOf.SourceGenerator" Version="3.0.263" />
<PackageReference Include="Polly" Version="8.0.0" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Refit" Version="7.0.0" />

Loading…
Cancel
Save