Browse Source

Add additional states that request model reindex

pull/165/head
Ionite 1 year ago
parent
commit
706a5d9cd2
No known key found for this signature in database
  1. 32
      StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs
  2. 2
      StabilityMatrix.Core/Models/Tokens/PromptExtraNetwork.cs
  3. 7
      StabilityMatrix.Core/Models/Tokens/PromptExtraNetworkType.cs
  4. 88
      StabilityMatrix.Core/Services/ModelIndexService.cs
  5. 98
      StabilityMatrix.Core/Services/TrackedDownloadService.cs

32
StabilityMatrix.Core/Models/CivitPostDownloadContextAction.cs

@ -1,5 +1,6 @@
using System.Diagnostics; using System.Diagnostics;
using System.Text.Json; using System.Text.Json;
using AsyncAwaitBestPractices;
using StabilityMatrix.Core.Models.Api; using StabilityMatrix.Core.Models.Api;
using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Services;
@ -9,36 +10,35 @@ public class CivitPostDownloadContextAction : IContextAction
{ {
/// <inheritdoc /> /// <inheritdoc />
public object? Context { get; set; } public object? Context { get; set; }
public static CivitPostDownloadContextAction FromCivitFile(CivitFile file) public static CivitPostDownloadContextAction FromCivitFile(CivitFile file)
{ {
return new CivitPostDownloadContextAction return new CivitPostDownloadContextAction { Context = file.Hashes.BLAKE3 };
{
Context = file.Hashes.BLAKE3
};
} }
public void Invoke(ISettingsManager settingsManager) public void Invoke(ISettingsManager settingsManager, IModelIndexService modelIndexService)
{ {
var result = Context as string; var result = Context as string;
if (Context is JsonElement jsonElement) if (Context is JsonElement jsonElement)
{ {
result = jsonElement.GetString(); result = jsonElement.GetString();
} }
if (result is null) if (result is null)
{ {
Debug.WriteLine($"Context {Context} is not a string."); Debug.WriteLine($"Context {Context} is not a string.");
return; return;
} }
Debug.WriteLine($"Adding {result} to installed models."); Debug.WriteLine($"Adding {result} to installed models.");
settingsManager.Transaction( settingsManager.Transaction(s =>
s => {
{ s.InstalledModelHashes ??= new HashSet<string>();
s.InstalledModelHashes ??= new HashSet<string>(); s.InstalledModelHashes.Add(result);
s.InstalledModelHashes.Add(result); });
});
// Also request reindex
modelIndexService.BackgroundRefreshIndex();
} }
} }

2
StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetwork.cs → StabilityMatrix.Core/Models/Tokens/PromptExtraNetwork.cs

@ -1,4 +1,4 @@
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens; namespace StabilityMatrix.Core.Models.Tokens;
/// <summary> /// <summary>
/// Represents an extra network token in a prompt. /// Represents an extra network token in a prompt.

7
StabilityMatrix.Avalonia/Models/Inference/Tokens/PromptExtraNetworkType.cs → StabilityMatrix.Core/Models/Tokens/PromptExtraNetworkType.cs

@ -1,8 +1,7 @@
using System; using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models; namespace StabilityMatrix.Core.Models.Tokens;
namespace StabilityMatrix.Avalonia.Models.Inference.Tokens;
[Flags] [Flags]
public enum PromptExtraNetworkType public enum PromptExtraNetworkType

88
StabilityMatrix.Core/Services/ModelIndexService.cs

@ -1,4 +1,5 @@
using System.Diagnostics; using System.Diagnostics;
using AsyncAwaitBestPractices;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Database; using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Extensions;
@ -14,8 +15,9 @@ public class ModelIndexService : IModelIndexService
private readonly ILiteDbContext liteDbContext; private readonly ILiteDbContext liteDbContext;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; private set; } = new(); public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; private set; } =
new();
public ModelIndexService( public ModelIndexService(
ILogger<ModelIndexService> logger, ILogger<ModelIndexService> logger,
ILiteDbContext liteDbContext, ILiteDbContext liteDbContext,
@ -41,9 +43,10 @@ public class ModelIndexService : IModelIndexService
return await liteDbContext.LocalModelFiles return await liteDbContext.LocalModelFiles
.Query() .Query()
.Where(m => m.SharedFolderType == type) .Where(m => m.SharedFolderType == type)
.ToArrayAsync().ConfigureAwait(false); .ToArrayAsync()
.ConfigureAwait(false);
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task RefreshIndex() public async Task RefreshIndex()
{ {
@ -52,21 +55,20 @@ public class ModelIndexService : IModelIndexService
// Start // Start
var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
logger.LogInformation("Refreshing model index..."); logger.LogInformation("Refreshing model index...");
using var db using var db = await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
= await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
var localModelFiles = db.GetCollection<LocalModelFile>("LocalModelFiles")!; var localModelFiles = db.GetCollection<LocalModelFile>("LocalModelFiles")!;
await localModelFiles.DeleteAllAsync().ConfigureAwait(false); await localModelFiles.DeleteAllAsync().ConfigureAwait(false);
// Record start of actual indexing // Record start of actual indexing
var indexStart = stopwatch.Elapsed; var indexStart = stopwatch.Elapsed;
var added = 0; var added = 0;
var newIndex = new Dictionary<SharedFolderType, List<LocalModelFile>>(); var newIndex = new Dictionary<SharedFolderType, List<LocalModelFile>>();
foreach ( foreach (
var file in modelsDir.Info var file in modelsDir.Info
.EnumerateFiles("*.*", SearchOption.AllDirectories) .EnumerateFiles("*.*", SearchOption.AllDirectories)
@ -78,70 +80,86 @@ public class ModelIndexService : IModelIndexService
{ {
continue; continue;
} }
var relativePath = Path.GetRelativePath(modelsDir, file); var relativePath = Path.GetRelativePath(modelsDir, file);
// Get shared folder name // Get shared folder name
var sharedFolderName = relativePath.Split(Path.DirectorySeparatorChar, var sharedFolderName = relativePath.Split(
StringSplitOptions.RemoveEmptyEntries)[0]; Path.DirectorySeparatorChar,
StringSplitOptions.RemoveEmptyEntries
)[0];
// Convert to enum // Convert to enum
var sharedFolderType = Enum.Parse<SharedFolderType>(sharedFolderName, true); var sharedFolderType = Enum.Parse<SharedFolderType>(sharedFolderName, true);
var localModel = new LocalModelFile var localModel = new LocalModelFile
{ {
RelativePath = relativePath, RelativePath = relativePath,
SharedFolderType = sharedFolderType, SharedFolderType = sharedFolderType,
}; };
// Try to find a connected model info // Try to find a connected model info
var jsonPath = file.Directory!.JoinFile( var jsonPath = file.Directory!.JoinFile(
new FilePath($"{file.NameWithoutExtension}.cm-info.json")); new FilePath($"{file.NameWithoutExtension}.cm-info.json")
);
if (jsonPath.Exists) if (jsonPath.Exists)
{ {
var connectedModelInfo = ConnectedModelInfo.FromJson( var connectedModelInfo = ConnectedModelInfo.FromJson(
await jsonPath.ReadAllTextAsync().ConfigureAwait(false)); await jsonPath.ReadAllTextAsync().ConfigureAwait(false)
);
localModel.ConnectedModelInfo = connectedModelInfo; localModel.ConnectedModelInfo = connectedModelInfo;
} }
// Try to find a preview image // Try to find a preview image
var previewImagePath = LocalModelFile.SupportedImageExtensions var previewImagePath = LocalModelFile.SupportedImageExtensions
.Select(ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}") .Select(
).FirstOrDefault(path => path.Exists); ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}")
)
.FirstOrDefault(path => path.Exists);
if (previewImagePath != null) if (previewImagePath != null)
{ {
localModel.PreviewImageRelativePath = Path.GetRelativePath(modelsDir, previewImagePath); localModel.PreviewImageRelativePath = Path.GetRelativePath(
modelsDir,
previewImagePath
);
} }
// Insert into database // Insert into database
await localModelFiles.InsertAsync(localModel).ConfigureAwait(false); await localModelFiles.InsertAsync(localModel).ConfigureAwait(false);
// Add to index // Add to index
var list = newIndex.GetOrAdd(sharedFolderType); var list = newIndex.GetOrAdd(sharedFolderType);
list.Add(localModel); list.Add(localModel);
added++; added++;
} }
// Update index // Update index
ModelIndex = newIndex; ModelIndex = newIndex;
// Record end of actual indexing // Record end of actual indexing
var indexEnd = stopwatch.Elapsed; var indexEnd = stopwatch.Elapsed;
await db.CommitAsync().ConfigureAwait(false); await db.CommitAsync().ConfigureAwait(false);
// End // End
stopwatch.Stop(); stopwatch.Stop();
var indexDuration = indexEnd - indexStart; var indexDuration = indexEnd - indexStart;
var dbDuration = stopwatch.Elapsed - indexDuration; var dbDuration = stopwatch.Elapsed - indexDuration;
logger.LogInformation("Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)", logger.LogInformation(
added, indexDuration.TotalMilliseconds, dbDuration.TotalMilliseconds); "Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)",
added,
indexDuration.TotalMilliseconds,
dbDuration.TotalMilliseconds
);
} }
/// <inheritdoc /> /// <inheritdoc />
public void BackgroundRefreshIndex() { } public void BackgroundRefreshIndex()
{
RefreshIndex().SafeFireAndForget();
}
} }

98
StabilityMatrix.Core/Services/TrackedDownloadService.cs

@ -14,34 +14,39 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
private readonly ILogger<TrackedDownloadService> logger; private readonly ILogger<TrackedDownloadService> logger;
private readonly IDownloadService downloadService; private readonly IDownloadService downloadService;
private readonly ISettingsManager settingsManager; private readonly ISettingsManager settingsManager;
private readonly IModelIndexService modelIndexService;
private readonly ConcurrentDictionary<Guid, (TrackedDownload, FileStream)> downloads = new(); private readonly ConcurrentDictionary<Guid, (TrackedDownload, FileStream)> downloads = new();
public IEnumerable<TrackedDownload> Downloads => downloads.Values.Select(x => x.Item1); public IEnumerable<TrackedDownload> Downloads => downloads.Values.Select(x => x.Item1);
/// <inheritdoc /> /// <inheritdoc />
public event EventHandler<TrackedDownload>? DownloadAdded; public event EventHandler<TrackedDownload>? DownloadAdded;
public TrackedDownloadService( public TrackedDownloadService(
ILogger<TrackedDownloadService> logger, ILogger<TrackedDownloadService> logger,
IDownloadService downloadService, IDownloadService downloadService,
ISettingsManager settingsManager) IModelIndexService modelIndexService,
ISettingsManager settingsManager
)
{ {
this.logger = logger; this.logger = logger;
this.downloadService = downloadService; this.downloadService = downloadService;
this.settingsManager = settingsManager; this.settingsManager = settingsManager;
this.modelIndexService = modelIndexService;
// Index for in-progress downloads when library dir loaded // Index for in-progress downloads when library dir loaded
settingsManager.RegisterOnLibraryDirSet(path => settingsManager.RegisterOnLibraryDirSet(path =>
{ {
var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory); var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory);
// Ignore if not exist // Ignore if not exist
if (!downloadsDir.Exists) return; if (!downloadsDir.Exists)
return;
LoadInProgressDownloads(downloadsDir); LoadInProgressDownloads(downloadsDir);
}); });
} }
private void OnDownloadAdded(TrackedDownload download) private void OnDownloadAdded(TrackedDownload download)
{ {
DownloadAdded?.Invoke(this, download); DownloadAdded?.Invoke(this, download);
@ -55,28 +60,32 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
// Set download service // Set download service
download.SetDownloadService(downloadService); download.SetDownloadService(downloadService);
// Create json file // Create json file
var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory); var downloadsDir = new DirectoryPath(settingsManager.DownloadsDirectory);
downloadsDir.Create(); downloadsDir.Create();
var jsonFile = downloadsDir.JoinFile($"{download.Id}.json"); var jsonFile = downloadsDir.JoinFile($"{download.Id}.json");
var jsonFileStream = jsonFile.Info.Open(FileMode.CreateNew, FileAccess.ReadWrite, FileShare.Read); var jsonFileStream = jsonFile.Info.Open(
FileMode.CreateNew,
FileAccess.ReadWrite,
FileShare.Read
);
// Serialize to json // Serialize to json
var json = JsonSerializer.Serialize(download); var json = JsonSerializer.Serialize(download);
jsonFileStream.Write(Encoding.UTF8.GetBytes(json)); jsonFileStream.Write(Encoding.UTF8.GetBytes(json));
jsonFileStream.Flush(); jsonFileStream.Flush();
// Add to dictionary // Add to dictionary
downloads.TryAdd(download.Id, (download, jsonFileStream)); downloads.TryAdd(download.Id, (download, jsonFileStream));
// Connect to state changed event to update json file // Connect to state changed event to update json file
AttachHandlers(download); AttachHandlers(download);
logger.LogDebug("Added download {Download}", download.FileName); logger.LogDebug("Added download {Download}", download.FileName);
OnDownloadAdded(download); OnDownloadAdded(download);
} }
/// <summary> /// <summary>
/// Update the json file for the download. /// Update the json file for the download.
/// </summary> /// </summary>
@ -85,19 +94,19 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
// Serialize to json // Serialize to json
var json = JsonSerializer.Serialize(download); var json = JsonSerializer.Serialize(download);
var jsonBytes = Encoding.UTF8.GetBytes(json); var jsonBytes = Encoding.UTF8.GetBytes(json);
// Write to file // Write to file
var (_, fs) = downloads[download.Id]; var (_, fs) = downloads[download.Id];
fs.Seek(0, SeekOrigin.Begin); fs.Seek(0, SeekOrigin.Begin);
fs.Write(jsonBytes); fs.Write(jsonBytes);
fs.Flush(); fs.Flush();
} }
private void AttachHandlers(TrackedDownload download) private void AttachHandlers(TrackedDownload download)
{ {
download.ProgressStateChanged += TrackedDownload_OnProgressStateChanged; download.ProgressStateChanged += TrackedDownload_OnProgressStateChanged;
} }
/// <summary> /// <summary>
/// Handler when the download's state changes /// Handler when the download's state changes
/// </summary> /// </summary>
@ -107,10 +116,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
return; return;
} }
// Update json file // Update json file
UpdateJsonForDownload(download); UpdateJsonForDownload(download);
// If the download is completed, remove it from the dictionary and delete the json file // If the download is completed, remove it from the dictionary and delete the json file
if (e is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled) if (e is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled)
{ {
@ -118,28 +127,30 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
{ {
downloadInfo.Item2.Dispose(); downloadInfo.Item2.Dispose();
// Delete json file // Delete json file
new DirectoryPath(settingsManager.DownloadsDirectory).JoinFile($"{download.Id}.json").Delete(); new DirectoryPath(settingsManager.DownloadsDirectory)
.JoinFile($"{download.Id}.json")
.Delete();
logger.LogDebug("Removed download {Download}", download.FileName); logger.LogDebug("Removed download {Download}", download.FileName);
} }
} }
// On successes, run the continuation action // On successes, run the continuation action
if (e == ProgressState.Success) if (e == ProgressState.Success)
{ {
if (download.ContextAction is CivitPostDownloadContextAction action) if (download.ContextAction is CivitPostDownloadContextAction action)
{ {
logger.LogDebug("Running context action for {Download}", download.FileName); logger.LogDebug("Running context action for {Download}", download.FileName);
action.Invoke(settingsManager); action.Invoke(settingsManager, modelIndexService);
} }
} }
} }
private void LoadInProgressDownloads(DirectoryPath downloadsDir) private void LoadInProgressDownloads(DirectoryPath downloadsDir)
{ {
logger.LogDebug("Indexing in-progress downloads at {DownloadsDir}...", downloadsDir); logger.LogDebug("Indexing in-progress downloads at {DownloadsDir}...", downloadsDir);
var jsonFiles = downloadsDir.Info.EnumerateFiles("*.json", SearchOption.TopDirectoryOnly); var jsonFiles = downloadsDir.Info.EnumerateFiles("*.json", SearchOption.TopDirectoryOnly);
// Add to dictionary, the file name is the guid // Add to dictionary, the file name is the guid
foreach (var file in jsonFiles) foreach (var file in jsonFiles)
{ {
@ -147,10 +158,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
try try
{ {
var fileStream = file.Open(FileMode.Open, FileAccess.ReadWrite, FileShare.Read); var fileStream = file.Open(FileMode.Open, FileAccess.ReadWrite, FileShare.Read);
// Deserialize json and add to dictionary // Deserialize json and add to dictionary
var download = JsonSerializer.Deserialize<TrackedDownload>(fileStream)!; var download = JsonSerializer.Deserialize<TrackedDownload>(fileStream)!;
// If the download is marked as working, pause it // If the download is marked as working, pause it
if (download.ProgressState == ProgressState.Working) if (download.ProgressState == ProgressState.Working)
{ {
@ -159,23 +170,30 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
else if (download.ProgressState != ProgressState.Inactive) else if (download.ProgressState != ProgressState.Inactive)
{ {
// If the download is not inactive, skip it // If the download is not inactive, skip it
logger.LogWarning("Skipping download {Download} with state {State}", download.FileName, download.ProgressState); logger.LogWarning(
"Skipping download {Download} with state {State}",
download.FileName,
download.ProgressState
);
fileStream.Dispose(); fileStream.Dispose();
// Delete json file // Delete json file
logger.LogDebug("Deleting json file for {Download} with unsupported state", download.FileName); logger.LogDebug(
"Deleting json file for {Download} with unsupported state",
download.FileName
);
file.Delete(); file.Delete();
continue; continue;
} }
download.SetDownloadService(downloadService); download.SetDownloadService(downloadService);
downloads.TryAdd(download.Id, (download, fileStream)); downloads.TryAdd(download.Id, (download, fileStream));
AttachHandlers(download); AttachHandlers(download);
OnDownloadAdded(download); OnDownloadAdded(download);
logger.LogDebug("Loaded in-progress download {Download}", download.FileName); logger.LogDebug("Loaded in-progress download {Download}", download.FileName);
} }
catch (Exception e) catch (Exception e)
@ -197,10 +215,10 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
}; };
AddDownload(download); AddDownload(download);
return download; return download;
} }
/// <summary> /// <summary>
/// Generate a new temp file name that is unique in the given directory. /// Generate a new temp file name that is unique in the given directory.
/// In format of "Unconfirmed {id}.smdownload" /// In format of "Unconfirmed {id}.smdownload"
@ -213,14 +231,14 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
for (var i = 0; i < 10; i++) for (var i = 0; i < 10; i++)
{ {
if (tempFile is {Exists: false}) if (tempFile is { Exists: false })
{ {
return tempFile.Name; return tempFile.Name;
} }
var id = Random.Shared.Next(1000000, 9999999); var id = Random.Shared.Next(1000000, 9999999);
tempFile = parentDir.JoinFile($"Unconfirmed {id}.smdownload"); tempFile = parentDir.JoinFile($"Unconfirmed {id}.smdownload");
} }
throw new Exception("Failed to generate a unique temp file name."); throw new Exception("Failed to generate a unique temp file name.");
} }
@ -241,7 +259,7 @@ public class TrackedDownloadService : ITrackedDownloadService, IDisposable
} }
} }
} }
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
} }

Loading…
Cancel
Save