Browse Source

Add additional states that request model reindex

pull/165/head
Ionite 1 year ago
parent
commit
706a5d9cd2
No known key found for this signature in database
  1. 32
      StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs
  2. 2
      StabilityMatrix.Core/Models/Tokens/PromptExtraNetwork.cs
  3. 7
      StabilityMatrix.Core/Models/Tokens/PromptExtraNetworkType.cs
  4. 88
      StabilityMatrix.Core/Services/ModelIndexService.cs
  5. 98
      StabilityMatrix.Core/Services/TrackedDownloadService.cs

32
StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs

@ -1,5 +1,6 @@
using System.Diagnostics;
using System.Text.Json;
using AsyncAwaitBestPractices;
using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Services;
@ -9,36 +10,35 @@ public class CivitPostDownloadContextAction : IContextAction
{
/// <inheritdoc />
public object? Context { get; set; }
public static CivitPostDownloadContextAction FromCivitFile(CivitFile file)
{
return new CivitPostDownloadContextAction
{
Context = file.Hashes.BLAKE3
};
return new CivitPostDownloadContextAction { Context = file.Hashes.BLAKE3 };
}
public void Invoke(ISettingsManager settingsManager)
public void Invoke(ISettingsManager settingsManager, IModelIndexService modelIndexService)
{
var result = Context as string;
if (Context is JsonElement jsonElement)
{
result = jsonElement.GetString();
}
if (result is null)
{
Debug.WriteLine($"Context {Context} is not a string.");
return;
}
Debug.WriteLine($"Adding {result} to installed models.");
settingsManager.Transaction(
s =>
{
s.InstalledModelHashes ??= new HashSet<string>();
s.InstalledModelHashes.Add(result);
});
settingsManager.Transaction(s =>
{
s.InstalledModelHashes ??= new HashSet<string>();
s.InstalledModelHashes.Add(result);
});
// Also request reindex
modelIndexService.BackgroundRefreshIndex();
}
}

2
StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetwork.cs → StabilityMatrix.Core/Models/Tokens/PromptExtraNetwork.cs

@ -1,4 +1,4 @@
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens;
namespace StabilityMatrix.Core.Models.Tokens;
/// <summary>
/// Represents an extra network token in a prompt.

7
StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetworkType.cs → StabilityMatrix.Core/Models/Tokens/PromptExtraNetworkType.cs

@ -1,8 +1,7 @@
using System;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Extensions;
namespace StabilityMatrix.Core.Models.Tokens;
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens;
[Flags]
public enum PromptExtraNetworkType

88
StabilityMatrix.Core/Services/ModelIndexService.cs

@ -1,4 +1,5 @@
using System.Diagnostics;
using AsyncAwaitBestPractices;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions;
@ -14,8 +15,9 @@ public class ModelIndexService : IModelIndexService
private readonly ILiteDbContext liteDbContext;
private readonly ISettingsManager settingsManager;
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; private set; } = new();
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; private set; } =
new();
public ModelIndexService(
ILogger<ModelIndexService> logger,
ILiteDbContext liteDbContext,
@ -41,9 +43,10 @@ public class ModelIndexService : IModelIndexService
return await liteDbContext.LocalModelFiles
.Query()
.Where(m => m.SharedFolderType == type)
.ToArrayAsync().ConfigureAwait(false);
.ToArrayAsync()
.ConfigureAwait(false);
}
/// <inheritdoc />
public async Task RefreshIndex()
{
@ -52,21 +55,20 @@ public class ModelIndexService : IModelIndexService
// Start
var stopwatch = Stopwatch.StartNew();
logger.LogInformation("Refreshing model index...");
using var db
= await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
using var db = await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
var localModelFiles = db.GetCollection<LocalModelFile>("LocalModelFiles")!;
await localModelFiles.DeleteAllAsync().ConfigureAwait(false);
// Record start of actual indexing
var indexStart = stopwatch.Elapsed;
var added = 0;
var newIndex = new Dictionary<SharedFolderType, List<LocalModelFile>>();
foreach (
var file in modelsDir.Info
.EnumerateFiles("*.*", SearchOption.AllDirectories)
@ -78,70 +80,86 @@ public class ModelIndexService : IModelIndexService
{
continue;
}
var relativePath = Path.GetRelativePath(modelsDir, file);
// Get shared folder name
var sharedFolderName = relativePath.Split(Path.DirectorySeparatorChar,
StringSplitOptions.RemoveEmptyEntries)[0];
var sharedFolderName = relativePath.Split(
Path.DirectorySeparatorChar,
StringSplitOptions.RemoveEmptyEntries
)[0];
// Convert to enum
var sharedFolderType = Enum.Parse<SharedFolderType>(sharedFolderName, true);
var localModel = new LocalModelFile
{
RelativePath = relativePath,
SharedFolderType = sharedFolderType,
};
// Try to find a connected model info
var jsonPath = file.Directory!.JoinFile(
new FilePath($"{file.NameWithoutExtension}.cm-info.json"));
new FilePath($"{file.NameWithoutExtension}.cm-info.json")
);
if (jsonPath.Exists)
{
var connectedModelInfo = ConnectedModelInfo.FromJson(
await jsonPath.ReadAllTextAsync().ConfigureAwait(false));
await jsonPath.ReadAllTextAsync().ConfigureAwait(false)
);
localModel.ConnectedModelInfo = connectedModelInfo;
}
// Try to find a preview image
var previewImagePath = LocalModelFile.SupportedImageExtensions
.Select(ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}")
).FirstOrDefault(path => path.Exists);
.Select(
ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}")
)
.FirstOrDefault(path => path.Exists);
if (previewImagePath != null)
{
localModel.PreviewImageRelativePath = Path.GetRelativePath(modelsDir, previewImagePath);
localModel.PreviewImageRelativePath = Path.GetRelativePath(
modelsDir,
previewImagePath
);
}
// Insert into database
await localModelFiles.InsertAsync(localModel).ConfigureAwait(false);
// Add to index
var list = newIndex.GetOrAdd(sharedFolderType);
list.Add(localModel);
added++;
}
// Update index
ModelIndex = newIndex;
// Record end of actual indexing
var indexEnd = stopwatch.Elapsed;
await db.CommitAsync().ConfigureAwait(false);
// End
stopwatch.Stop();
var indexDuration = indexEnd - indexStart;
var dbDuration = stopwatch.Elapsed - indexDuration;
logger.LogInformation("Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)",
added, indexDuration.TotalMilliseconds, dbDuration.TotalMilliseconds);
logger.LogInformation(
"Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)",
added,
indexDuration.TotalMilliseconds,
dbDuration.TotalMilliseconds
);
}
/// <inheritdoc />
public void BackgroundRefreshIndex() { }
public void BackgroundRefreshIndex()
{
RefreshIndex().SafeFireAndForget();
}
}

98
StabilityMatrix.Core/Services/TrackedDownloadService.cs

@ -14,34 +14,39 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
private readonly ILogger<TrackedDownloadService> logger;
private readonly IDownloadService downloadService;
private readonly ISettingsManager settingsManager;
private readonly IModelIndexService modelIndexService;
private readonly ConcurrentDictionary<Guid, (TrackedDownload, FileStream)> downloads = new();
public IEnumerable<TrackedDownload> Downloads => downloads.Values.Select(x => x.Item1);
/// <inheritdoc />
public event EventHandler<TrackedDownload>? DownloadAdded;
public TrackedDownloadService(
ILogger<TrackedDownloadService> logger,
IDownloadService downloadService,
ISettingsManager settingsManager)
IModelIndexService modelIndexService,
ISettingsManager settingsManager
)
{
this.logger = logger;
this.downloadService = downloadService;
this.settingsManager = settingsManager;
this.modelIndexService = modelIndexService;
// Index for in-progress downloads when library dir loaded
settingsManager.RegisterOnLibraryDirSet(path =>
{
var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory);
// Ignore if not exist
if (!downloadsDir.Exists) return;
if (!downloadsDir.Exists)
return;
LoadInProgressDownloads(downloadsDir);
});
}
private void OnDownloadAdded(TrackedDownload download)
{
DownloadAdded?.Invoke(this, download);
@ -55,28 +60,32 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{
// Set download service
download.SetDownloadService(downloadService);
// Create json file
var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory);
downloadsDir.Create();
var jsonFile = downloadsDir.JoinFile($"{download.Id}.json");
var jsonFileStream = jsonFile.Info.Open(FileMode.CreateNew, FileAccess.ReadWrite, FileShare.Read);
var jsonFileStream = jsonFile.Info.Open(
FileMode.CreateNew,
FileAccess.ReadWrite,
FileShare.Read
);
// Serialize to json
var json = JsonSerializer.Serialize(download);
jsonFileStream.Write(Encoding.UTF8.GetBytes(json));
jsonFileStream.Flush();
// Add to dictionary
downloads.TryAdd(download.Id, (download, jsonFileStream));
// Connect to state changed event to update json file
AttachHandlers(download);
logger.LogDebug("Added download {Download}", download.FileName);
OnDownloadAdded(download);
}
/// <summary>
/// Update the json file for the download.
/// </summary>
@ -85,19 +94,19 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
// Serialize to json
var json = JsonSerializer.Serialize(download);
var jsonBytes = Encoding.UTF8.GetBytes(json);
// Write to file
var (_, fs) = downloads[download.Id];
fs.Seek(0, SeekOrigin.Begin);
fs.Write(jsonBytes);
fs.Flush();
}
private void AttachHandlers(TrackedDownload download)
{
download.ProgressStateChanged += TrackedDownload_OnProgressStateChanged;
}
/// <summary>
/// Handler when the download's state changes
/// </summary>
@ -107,10 +116,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{
return;
}
// Update json file
UpdateJsonForDownload(download);
// If the download is completed, remove it from the dictionary and delete the json file
if (e is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled)
{
@ -118,28 +127,30 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{
downloadInfo.Item2.Dispose();
// Delete json file
new DirectoryPath(settingsManager.DownloadsDirectory).JoinFile($"{download.Id}.json").Delete();
new DirectoryPath(settingsManager.DownloadsDirectory)
.JoinFile($"{download.Id}.json")
.Delete();
logger.LogDebug("Removed download {Download}", download.FileName);
}
}
// On successes, run the continuation action
if (e == ProgressState.Success)
{
if (download.ContextAction is CivitPostDownloadContextAction action)
{
logger.LogDebug("Running context action for {Download}", download.FileName);
action.Invoke(settingsManager);
action.Invoke(settingsManager, modelIndexService);
}
}
}
private void LoadInProgressDownloads(DirectoryPath downloadsDir)
{
logger.LogDebug("Indexing in-progress downloads at {DownloadsDir}...", downloadsDir);
var jsonFiles = downloadsDir.Info.EnumerateFiles("*.json", SearchOption.TopDirectoryOnly);
// Add to dictionary, the file name is the guid
foreach (var file in jsonFiles)
{
@ -147,10 +158,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
try
{
var fileStream = file.Open(FileMode.Open, FileAccess.ReadWrite, FileShare.Read);
// Deserialize json and add to dictionary
var download = JsonSerializer.Deserialize<TrackedDownload>(fileStream)!;
// If the download is marked as working, pause it
if (download.ProgressState == ProgressState.Working)
{
@ -159,23 +170,30 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
else if (download.ProgressState != ProgressState.Inactive)
{
// If the download is not inactive, skip it
logger.LogWarning("Skipping download {Download} with state {State}", download.FileName, download.ProgressState);
logger.LogWarning(
"Skipping download {Download} with state {State}",
download.FileName,
download.ProgressState
);
fileStream.Dispose();
// Delete json file
logger.LogDebug("Deleting json file for {Download} with unsupported state", download.FileName);
logger.LogDebug(
"Deleting json file for {Download} with unsupported state",
download.FileName
);
file.Delete();
continue;
}
download.SetDownloadService(downloadService);
downloads.TryAdd(download.Id, (download, fileStream));
AttachHandlers(download);
OnDownloadAdded(download);
logger.LogDebug("Loaded in-progress download {Download}", download.FileName);
}
catch (Exception e)
@ -197,10 +215,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
};
AddDownload(download);
return download;
}
/// <summary>
/// Generate a new temp file name that is unique in the given directory.
/// In format of "Unconfirmed {id}.smdownload"
@ -213,14 +231,14 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
for (var i = 0; i < 10; i++)
{
if (tempFile is {Exists: false})
if (tempFile is { Exists: false })
{
return tempFile.Name;
}
var id = Random.Shared.Next(1000000, 9999999);
tempFile = parentDir.JoinFile($"Unconfirmed {id}.smdownload");
}
throw new Exception("Failed to generate a unique temp file name.");
}
@ -241,7 +259,7 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
}
}
}
GC.SuppressFinalize(this);
}
}

Loading…
Cancel
Save