Browse Source

Throw exceptions when out of free space instead of whatever the fk it does now

pull/341/head
JT 11 months ago
parent
commit
d646db150c
  1. 5
      StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs
  2. 2
      StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs
  3. 40
      StabilityMatrix.Core/Helper/SystemInfo.cs
  4. 138
      StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
  5. 108
      StabilityMatrix.Core/Services/DownloadService.cs
  6. 6
      StabilityMatrix.Core/Services/SettingsManager.cs

5
StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs

@ -130,11 +130,6 @@ public class InferenceTextToImageViewModel : InferenceGenerationViewModelBase, I
_ => Convert.ToUInt64(SeedCardViewModel.Seed)
};
if (!SamplerCardViewModel.IsDenoiseStrengthEnabled)
{
SamplerCardViewModel.DenoiseStrength = 1.0d;
}
BatchSizeCardViewModel.ApplyStep(args);
// Load models

2
StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

@ -115,6 +115,8 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo
e.Builder.Connections.PrimarySize = new Size(Width, Height);
}
DenoiseStrength = IsDenoiseStrengthEnabled ? DenoiseStrength : 1.0d;
// Provide temp values
e.Temp.Conditioning = (
e.Builder.Connections.BaseConditioning!,

40
StabilityMatrix.Core/Helper/SystemInfo.cs

@ -1,9 +1,47 @@
using System.Runtime.InteropServices;
using NLog;
namespace StabilityMatrix.Core.Helper;
public static class SystemInfo
{
[DllImport("UXTheme.dll", SetLastError = true, EntryPoint = "#138")]
public const long Gigabyte = 1024 * 1024 * 1024;
public const long Megabyte = 1024 * 1024;
[DllImport("UXTheme.dll", SetLastError = true, EntryPoint = "#138")]
public static extern bool ShouldUseDarkMode();
[DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)]
[return: MarshalAs(UnmanagedType.Bool)]
private static extern bool GetDiskFreeSpaceEx(
string lpDirectoryName,
out long lpFreeBytesAvailable,
out long lpTotalNumberOfBytes,
out long lpTotalNumberOfFreeBytes
);
public static long? GetDiskFreeSpaceBytes(string path)
{
long? freeBytes = null;
try
{
if (Compat.IsWindows)
{
if (GetDiskFreeSpaceEx(path, out var freeBytesOut, out var _, out var _))
freeBytes = freeBytesOut;
}
if (freeBytes == null)
{
var drive = new DriveInfo(path);
freeBytes = drive.AvailableFreeSpace;
}
}
catch (Exception e)
{
LogManager.GetCurrentClassLogger().Error(e);
}
return freeBytes;
}
}

138
StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs

@ -36,8 +36,7 @@ public abstract class BaseGitPackage : BasePackage
public override string GithubUrl => $"https://github.com/{Author}/{Name}";
public string DownloadLocation =>
Path.Combine(SettingsManager.LibraryDir, "Packages", $"{Name}.zip");
public string DownloadLocation => Path.Combine(SettingsManager.LibraryDir, "Packages", $"{Name}.zip");
protected string GetDownloadUrl(DownloadPackageVersionOptions versionOptions)
{
@ -72,9 +71,7 @@ public abstract class BaseGitPackage : BasePackage
PrerequisiteHelper = prerequisiteHelper;
}
public override async Task<DownloadPackageVersionOptions> GetLatestVersion(
bool includePrerelease = false
)
public override async Task<DownloadPackageVersionOptions> GetLatestVersion(bool includePrerelease = false)
{
if (ShouldIgnoreReleases)
{
@ -87,9 +84,7 @@ public abstract class BaseGitPackage : BasePackage
}
var releases = await GithubApi.GetAllReleases(Author, Name).ConfigureAwait(false);
var latestRelease = includePrerelease
? releases.First()
: releases.First(x => !x.Prerelease);
var latestRelease = includePrerelease ? releases.First() : releases.First(x => !x.Prerelease);
return new DownloadPackageVersionOptions
{
@ -99,11 +94,7 @@ public abstract class BaseGitPackage : BasePackage
};
}
public override Task<IEnumerable<GitCommit>?> GetAllCommits(
string branch,
int page = 1,
int perPage = 10
)
public override Task<IEnumerable<GitCommit>?> GetAllCommits(string branch, int page = 1, int perPage = 10)
{
return GithubApi.GetAllCommits(Author, Name, branch, page, perPage);
}
@ -181,6 +172,14 @@ public abstract class BaseGitPackage : BasePackage
IProgress<ProgressReport>? progress = null
)
{
const long fiveGigs = 5 * SystemInfo.Gigabyte;
if (SystemInfo.GetDiskFreeSpaceBytes(installLocation) < fiveGigs)
{
throw new ApplicationException(
$"Not enough space to download {Name} to {installLocation}, need at least 5GB"
);
}
await PrerequisiteHelper
.RunGit(
new[]
@ -223,10 +222,7 @@ public abstract class BaseGitPackage : BasePackage
zipDirName = entry.FullName;
}
var folderPath = Path.Combine(
installLocation,
entry.FullName.Replace(zipDirName, string.Empty)
);
var folderPath = Path.Combine(installLocation, entry.FullName.Replace(zipDirName, string.Empty));
Directory.CreateDirectory(folderPath);
continue;
}
@ -237,10 +233,7 @@ public abstract class BaseGitPackage : BasePackage
entry.ExtractToFile(destinationPath, true);
progress?.Report(
new ProgressReport(
current: Convert.ToUInt64(currentEntry),
total: Convert.ToUInt64(totalEntries)
)
new ProgressReport(current: Convert.ToUInt64(currentEntry), total: Convert.ToUInt64(totalEntries))
);
}
@ -264,16 +257,12 @@ public abstract class BaseGitPackage : BasePackage
{
if (currentVersion.IsReleaseMode)
{
var latestVersion = await GetLatestVersion(currentVersion.IsPrerelease)
.ConfigureAwait(false);
UpdateAvailable =
latestVersion.VersionTag != currentVersion.InstalledReleaseVersion;
var latestVersion = await GetLatestVersion(currentVersion.IsPrerelease).ConfigureAwait(false);
UpdateAvailable = latestVersion.VersionTag != currentVersion.InstalledReleaseVersion;
return UpdateAvailable;
}
var allCommits = (
await GetAllCommits(currentVersion.InstalledBranch!).ConfigureAwait(false)
)?.ToList();
var allCommits = (await GetAllCommits(currentVersion.InstalledBranch!).ConfigureAwait(false))?.ToList();
if (allCommits == null || !allCommits.Any())
{
@ -305,18 +294,10 @@ public abstract class BaseGitPackage : BasePackage
if (!Directory.Exists(Path.Combine(installedPackage.FullPath!, ".git")))
{
Logger.Info("not a git repo, initializing...");
progress?.Report(
new ProgressReport(-1f, "Initializing git repo", isIndeterminate: true)
);
await PrerequisiteHelper
.RunGit("init", onConsoleOutput, installedPackage.FullPath)
.ConfigureAwait(false);
progress?.Report(new ProgressReport(-1f, "Initializing git repo", isIndeterminate: true));
await PrerequisiteHelper.RunGit("init", onConsoleOutput, installedPackage.FullPath).ConfigureAwait(false);
await PrerequisiteHelper
.RunGit(
new[] { "remote", "add", "origin", GithubUrl },
onConsoleOutput,
installedPackage.FullPath
)
.RunGit(new[] { "remote", "add", "origin", GithubUrl }, onConsoleOutput, installedPackage.FullPath)
.ConfigureAwait(false);
}
@ -328,11 +309,7 @@ public abstract class BaseGitPackage : BasePackage
.ConfigureAwait(false);
progress?.Report(
new ProgressReport(
-1f,
$"Checking out {versionOptions.VersionTag}",
isIndeterminate: true
)
new ProgressReport(-1f, $"Checking out {versionOptions.VersionTag}", isIndeterminate: true)
);
await PrerequisiteHelper
.RunGit(
@ -361,9 +338,7 @@ public abstract class BaseGitPackage : BasePackage
// fetch
progress?.Report(new ProgressReport(-1f, "Fetching data...", isIndeterminate: true));
await PrerequisiteHelper
.RunGit("fetch", onConsoleOutput, installedPackage.FullPath)
.ConfigureAwait(false);
await PrerequisiteHelper.RunGit("fetch", onConsoleOutput, installedPackage.FullPath).ConfigureAwait(false);
if (versionOptions.IsLatest)
{
@ -387,13 +362,7 @@ public abstract class BaseGitPackage : BasePackage
progress?.Report(new ProgressReport(-1f, "Pulling changes...", isIndeterminate: true));
await PrerequisiteHelper
.RunGit(
new[]
{
"pull",
"--autostash",
"origin",
installedPackage.Version.InstalledBranch!
},
new[] { "pull", "--autostash", "origin", installedPackage.Version.InstalledBranch! },
onConsoleOutput,
installedPackage.FullPath!
)
@ -436,51 +405,39 @@ public abstract class BaseGitPackage : BasePackage
};
}
public override Task SetupModelFolders(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
public override Task SetupModelFolders(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod)
{
if (sharedFolderMethod == SharedFolderMethod.Symlink && SharedFolders is { } folders)
{
return StabilityMatrix.Core.Helper.SharedFolders.UpdateLinksForPackage(
folders,
SettingsManager.ModelsDirectory,
installDirectory
);
return StabilityMatrix
.Core
.Helper
.SharedFolders
.UpdateLinksForPackage(folders, SettingsManager.ModelsDirectory, installDirectory);
}
return Task.CompletedTask;
}
public override Task UpdateModelFolders(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
public override Task UpdateModelFolders(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod)
{
if (sharedFolderMethod == SharedFolderMethod.Symlink && SharedFolders is { } sharedFolders)
{
return StabilityMatrix.Core.Helper.SharedFolders.UpdateLinksForPackage(
sharedFolders,
SettingsManager.ModelsDirectory,
installDirectory
);
return StabilityMatrix
.Core
.Helper
.SharedFolders
.UpdateLinksForPackage(sharedFolders, SettingsManager.ModelsDirectory, installDirectory);
}
return Task.CompletedTask;
}
public override Task RemoveModelFolderLinks(
DirectoryPath installDirectory,
SharedFolderMethod sharedFolderMethod
)
public override Task RemoveModelFolderLinks(DirectoryPath installDirectory, SharedFolderMethod sharedFolderMethod)
{
if (SharedFolders is not null && sharedFolderMethod == SharedFolderMethod.Symlink)
{
StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(
SharedFolders,
installDirectory
);
StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(SharedFolders, installDirectory);
}
return Task.CompletedTask;
}
@ -489,12 +446,16 @@ public abstract class BaseGitPackage : BasePackage
{
if (SharedOutputFolders is { } sharedOutputFolders)
{
return StabilityMatrix.Core.Helper.SharedFolders.UpdateLinksForPackage(
sharedOutputFolders,
SettingsManager.ImagesDirectory,
installDirectory,
recursiveDelete: true
);
return StabilityMatrix
.Core
.Helper
.SharedFolders
.UpdateLinksForPackage(
sharedOutputFolders,
SettingsManager.ImagesDirectory,
installDirectory,
recursiveDelete: true
);
}
return Task.CompletedTask;
@ -504,10 +465,7 @@ public abstract class BaseGitPackage : BasePackage
{
if (SharedOutputFolders is { } sharedOutputFolders)
{
StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(
sharedOutputFolders,
installDirectory
);
StabilityMatrix.Core.Helper.SharedFolders.RemoveLinksForPackage(sharedOutputFolders, installDirectory);
}
return Task.CompletedTask;
}

108
StabilityMatrix.Core/Services/DownloadService.cs

@ -39,18 +39,11 @@ public class DownloadService : IDownloadService
: httpClientFactory.CreateClient(httpClientName);
client.Timeout = TimeSpan.FromMinutes(10);
client.DefaultRequestHeaders.UserAgent.Add(
new ProductInfoHeaderValue("StabilityMatrix", "2.0")
);
client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("StabilityMatrix", "2.0"));
await AddConditionalHeaders(client, new Uri(downloadUrl)).ConfigureAwait(false);
await using var file = new FileStream(
downloadPath,
FileMode.Create,
FileAccess.Write,
FileShare.None
);
await using var file = new FileStream(downloadPath, FileMode.Create, FileAccess.Write, FileShare.None);
long contentLength = 0;
@ -59,10 +52,7 @@ public class DownloadService : IDownloadService
.ConfigureAwait(false);
contentLength = response.Content.Headers.ContentLength ?? 0;
var delays = Backoff.DecorrelatedJitterBackoffV2(
TimeSpan.FromMilliseconds(50),
retryCount: 3
);
var delays = Backoff.DecorrelatedJitterBackoffV2(TimeSpan.FromMilliseconds(50), retryCount: 3);
foreach (var delay in delays)
{
@ -77,9 +67,19 @@ public class DownloadService : IDownloadService
}
var isIndeterminate = contentLength == 0;
await using var stream = await response.Content
.ReadAsStreamAsync(cancellationToken)
.ConfigureAwait(false);
if (contentLength > 0)
{
// check free space
var freeSpace = SystemInfo.GetDiskFreeSpaceBytes(Path.GetDirectoryName(downloadPath));
if (freeSpace < contentLength)
{
throw new ApplicationException(
$"Not enough free space to download file. Free: {freeSpace} bytes, Required: {contentLength} bytes"
);
}
}
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
var totalBytesRead = 0L;
var buffer = new byte[BufferSize];
while (true)
@ -87,8 +87,7 @@ public class DownloadService : IDownloadService
var bytesRead = await stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
break;
await file.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken)
.ConfigureAwait(false);
await file.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false);
totalBytesRead += bytesRead;
@ -130,9 +129,7 @@ public class DownloadService : IDownloadService
using var noRedirectClient = httpClientFactory.CreateClient("DontFollowRedirects");
client.Timeout = TimeSpan.FromMinutes(10);
client.DefaultRequestHeaders.UserAgent.Add(
new ProductInfoHeaderValue("StabilityMatrix", "2.0")
);
client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("StabilityMatrix", "2.0"));
await AddConditionalHeaders(client, new Uri(downloadUrl)).ConfigureAwait(false);
await AddConditionalHeaders(noRedirectClient, new Uri(downloadUrl)).ConfigureAwait(false);
@ -140,19 +137,11 @@ public class DownloadService : IDownloadService
// Create file if it doesn't exist
if (!File.Exists(downloadPath))
{
logger.LogInformation(
"Resume file doesn't exist, creating file {DownloadPath}",
downloadPath
);
logger.LogInformation("Resume file doesn't exist, creating file {DownloadPath}", downloadPath);
File.Create(downloadPath).Close();
}
await using var file = new FileStream(
downloadPath,
FileMode.Append,
FileAccess.Write,
FileShare.None
);
await using var file = new FileStream(downloadPath, FileMode.Append, FileAccess.Write, FileShare.None);
// Remaining content length
long remainingContentLength = 0;
@ -165,24 +154,13 @@ public class DownloadService : IDownloadService
noRedirectRequest.Headers.Range = new RangeHeaderValue(existingFileSize, null);
HttpResponseMessage? response = null;
foreach (
var delay in Backoff.DecorrelatedJitterBackoffV2(
TimeSpan.FromMilliseconds(50),
retryCount: 4
)
)
foreach (var delay in Backoff.DecorrelatedJitterBackoffV2(TimeSpan.FromMilliseconds(50), retryCount: 4))
{
var noRedirectResponse = await noRedirectClient
.SendAsync(
noRedirectRequest,
HttpCompletionOption.ResponseHeadersRead,
cancellationToken
)
.SendAsync(noRedirectRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
.ConfigureAwait(false);
if (
(int)noRedirectResponse.StatusCode > 299 && (int)noRedirectResponse.StatusCode < 400
)
if ((int)noRedirectResponse.StatusCode > 299 && (int)noRedirectResponse.StatusCode < 400)
{
var redirectUrl = noRedirectResponse.Headers.Location?.ToString();
if (redirectUrl != null && redirectUrl.Contains("reason=download-auth"))
@ -197,16 +175,11 @@ public class DownloadService : IDownloadService
redirectRequest.Headers.Range = new RangeHeaderValue(existingFileSize, null);
response = await client
.SendAsync(
redirectRequest,
HttpCompletionOption.ResponseHeadersRead,
cancellationToken
)
.SendAsync(redirectRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
.ConfigureAwait(false);
remainingContentLength = response.Content.Headers.ContentLength ?? 0;
originalContentLength =
response.Content.Headers.ContentRange?.Length.GetValueOrDefault() ?? 0;
originalContentLength = response.Content.Headers.ContentRange?.Length.GetValueOrDefault() ?? 0;
if (remainingContentLength > 0)
break;
@ -222,9 +195,7 @@ public class DownloadService : IDownloadService
var isIndeterminate = remainingContentLength == 0;
await using var stream = await response.Content
.ReadAsStreamAsync(cancellationToken)
.ConfigureAwait(false);
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
var totalBytesRead = 0L;
var buffer = new byte[BufferSize];
while (true)
@ -234,8 +205,7 @@ public class DownloadService : IDownloadService
var bytesRead = await stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
break;
await file.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken)
.ConfigureAwait(false);
await file.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false);
totalBytesRead += bytesRead;
@ -274,20 +244,13 @@ public class DownloadService : IDownloadService
: httpClientFactory.CreateClient(httpClientName);
client.Timeout = TimeSpan.FromMinutes(10);
client.DefaultRequestHeaders.UserAgent.Add(
new ProductInfoHeaderValue("StabilityMatrix", "2.0")
);
client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("StabilityMatrix", "2.0"));
await AddConditionalHeaders(client, new Uri(downloadUrl)).ConfigureAwait(false);
var contentLength = 0L;
foreach (
var delay in Backoff.DecorrelatedJitterBackoffV2(
TimeSpan.FromMilliseconds(50),
retryCount: 3
)
)
foreach (var delay in Backoff.DecorrelatedJitterBackoffV2(TimeSpan.FromMilliseconds(50), retryCount: 3))
{
var response = await client
.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
@ -308,9 +271,7 @@ public class DownloadService : IDownloadService
{
using var client = httpClientFactory.CreateClient();
client.Timeout = TimeSpan.FromSeconds(10);
client.DefaultRequestHeaders.UserAgent.Add(
new ProductInfoHeaderValue("StabilityMatrix", "2.0")
);
client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("StabilityMatrix", "2.0"));
await AddConditionalHeaders(client, new Uri(url)).ConfigureAwait(false);
try
{
@ -333,19 +294,14 @@ public class DownloadService : IDownloadService
if (url.Host.Equals("civitai.com", StringComparison.OrdinalIgnoreCase))
{
// Add auth if we have it
if (
await secretsManager.LoadAsync().ConfigureAwait(false) is { CivitApi: { } civitApi }
)
if (await secretsManager.LoadAsync().ConfigureAwait(false) is { CivitApi: { } civitApi })
{
logger.LogTrace(
"Adding Civit auth header {Signature} for download {Url}",
ObjectHash.GetStringSignature(civitApi.ApiToken),
url
);
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue(
"Bearer",
civitApi.ApiToken
);
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", civitApi.ApiToken);
}
}
}

6
StabilityMatrix.Core/Services/SettingsManager.cs

@ -650,6 +650,12 @@ public class SettingsManager : ISettingsManager
if (!isLoaded)
return;
if (SystemInfo.GetDiskFreeSpaceBytes(SettingsPath) < 1024 * 1024)
{
Logger.Warn("Not enough disk space to save settings");
return;
}
var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(Settings, SettingsSerializerContext.Default.Settings);
File.WriteAllBytes(SettingsPath, jsonBytes);

Loading…
Cancel
Save