Browse Source

Merge pull request #226 from ionite34/models-index

pull/109/head
Ionite 1 year ago committed by GitHub
parent
commit
fd384908ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      StabilityMatrix.Avalonia/App.axaml.cs
  2. 1
      StabilityMatrix.Avalonia/DesignData/DesignData.cs
  3. 2
      StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs
  4. 27
      StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs
  5. 13
      StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs
  6. 12
      StabilityMatrix.Avalonia/Views/SettingsPage.axaml
  7. 1
      StabilityMatrix.Core/Database/ILiteDbContext.cs
  8. 1
      StabilityMatrix.Core/Database/LiteDbContext.cs
  9. 47
      StabilityMatrix.Core/Models/Database/LocalModelFile.cs
  10. 5
      StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs
  11. 22
      StabilityMatrix.Core/Services/IModelIndexService.cs
  12. 135
      StabilityMatrix.Core/Services/ModelIndexService.cs

1
StabilityMatrix.Avalonia/App.axaml.cs

@ -346,6 +346,7 @@ public sealed class App : Application
services.AddSingleton<IPyRunner, PyRunner>();
services.AddSingleton<IUpdateHelper, UpdateHelper>();
services.AddSingleton<INavigationService, NavigationService>();
services.AddSingleton<IModelIndexService, ModelIndexService>();
services.AddSingleton<ITrackedDownloadService, TrackedDownloadService>();
services.AddSingleton<IDisposable>(provider =>

1
StabilityMatrix.Avalonia/DesignData/DesignData.cs

@ -89,6 +89,7 @@ public static class DesignData
.AddSingleton<IDownloadService, MockDownloadService>()
.AddSingleton<IHttpClientFactory, MockHttpClientFactory>()
.AddSingleton<IDiscordRichPresenceService, MockDiscordRichPresenceService>()
.AddSingleton<IModelIndexService, MockModelIndexService>()
.AddSingleton<ITrackedDownloadService, MockTrackedDownloadService>();
// Placeholder services that nobody should need during design time

2
StabilityMatrix.Avalonia/DesignData/MockLiteDbContext.cs

@ -14,6 +14,8 @@ public class MockLiteDbContext : ILiteDbContext
public ILiteCollectionAsync<CivitModel> CivitModels => throw new NotImplementedException();
public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions => throw new NotImplementedException();
public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache => throw new NotImplementedException();
public ILiteCollectionAsync<LocalModelFile> LocalModelFiles => throw new NotImplementedException();
public Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(string hashBlake3)
{
return Task.FromResult<(CivitModel?, CivitModelVersion?)>((null, null));

27
StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs

@ -0,0 +1,27 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.DesignData;
public class MockModelIndexService : IModelIndexService
{
/// <inheritdoc />
public Task RefreshIndex()
{
return Task.CompletedTask;
}
/// <inheritdoc />
public Task<IReadOnlyList<LocalModelFile>> GetModelsOfType(SharedFolderType type)
{
return Task.FromResult<IReadOnlyList<LocalModelFile>>(new List<LocalModelFile>());
}
/// <inheritdoc />
public void BackgroundRefreshIndex()
{
}
}

13
StabilityMatrix.Avalonia/ViewModels/SettingsViewModel.cs

@ -50,6 +50,7 @@ public partial class SettingsViewModel : PageViewModelBase
private readonly IPyRunner pyRunner;
private readonly ServiceManager<ViewModelBase> dialogFactory;
private readonly ITrackedDownloadService trackedDownloadService;
private readonly IModelIndexService modelIndexService;
public SharedState SharedState { get; }
@ -114,7 +115,8 @@ public partial class SettingsViewModel : PageViewModelBase
IPyRunner pyRunner,
ServiceManager<ViewModelBase> dialogFactory,
SharedState sharedState,
ITrackedDownloadService trackedDownloadService)
ITrackedDownloadService trackedDownloadService,
IModelIndexService modelIndexService)
{
this.notificationService = notificationService;
this.settingsManager = settingsManager;
@ -122,7 +124,8 @@ public partial class SettingsViewModel : PageViewModelBase
this.pyRunner = pyRunner;
this.dialogFactory = dialogFactory;
this.trackedDownloadService = trackedDownloadService;
this.modelIndexService = modelIndexService;
SharedState = sharedState;
SelectedTheme = settingsManager.Settings.Theme ?? AvailableThemes[1];
@ -501,6 +504,12 @@ public partial class SettingsViewModel : PageViewModelBase
throw new OperationCanceledException("Example Message");
}
[RelayCommand]
private async Task DebugRefreshModelsIndex()
{
await modelIndexService.RefreshIndex();
}
[RelayCommand]
private async Task DebugTrackedDownload()
{

12
StabilityMatrix.Avalonia/Views/SettingsPage.axaml

@ -307,6 +307,18 @@
Content="Add Tracked Download" />
</ui:SettingsExpanderItem.Footer>
</ui:SettingsExpanderItem>
<ui:SettingsExpanderItem
Margin="4,0,4,4"
Content="Refresh Models Index"
IconSource="SyncFolder">
<ui:SettingsExpanderItem.Footer>
<Button
Margin="0,8"
Command="{Binding DebugRefreshModelsIndexCommand}"
Content="Refresh Index" />
</ui:SettingsExpanderItem.Footer>
</ui:SettingsExpanderItem>
</ui:SettingsExpander>
</Grid>

1
StabilityMatrix.Core/Database/ILiteDbContext.cs

@ -11,6 +11,7 @@ public interface ILiteDbContext : IDisposable
ILiteCollectionAsync<CivitModel> CivitModels { get; }
ILiteCollectionAsync<CivitModelVersion> CivitModelVersions { get; }
ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache { get; }
ILiteCollectionAsync<LocalModelFile> LocalModelFiles { get; }
Task<(CivitModel?, CivitModelVersion?)> FindCivitModelFromFileHashAsync(string hashBlake3);

1
StabilityMatrix.Core/Database/LiteDbContext.cs

@ -27,6 +27,7 @@ public class LiteDbContext : ILiteDbContext
public ILiteCollectionAsync<CivitModelVersion> CivitModelVersions => Database.GetCollection<CivitModelVersion>("CivitModelVersions");
public ILiteCollectionAsync<CivitModelQueryCacheEntry> CivitModelQueryCache => Database.GetCollection<CivitModelQueryCacheEntry>("CivitModelQueryCache");
public ILiteCollectionAsync<GithubCacheEntry> GithubCache => Database.GetCollection<GithubCacheEntry>("GithubCache");
public ILiteCollectionAsync<LocalModelFile> LocalModelFiles => Database.GetCollection<LocalModelFile>("LocalModelFiles");
public LiteDbContext(
ILogger<LiteDbContext> logger,

47
StabilityMatrix.Core/Models/Database/LocalModelFile.cs

@ -0,0 +1,47 @@
using LiteDB;
namespace StabilityMatrix.Core.Models.Database;
/// <summary>
/// Represents a locally indexed model file.
/// </summary>
public class LocalModelFile
{
/// <summary>
/// Relative path to the file from the root model directory.
/// </summary>
[BsonId]
public required string RelativePath { get; set; }
/// <summary>
/// Type of the model file.
/// </summary>
public required SharedFolderType SharedFolderType { get; set; }
/// <summary>
/// Optional connected model information.
/// </summary>
public ConnectedModelInfo? ConnectedModelInfo { get; set; }
/// <summary>
/// Optional preview image relative path.
/// </summary>
public string? PreviewImageRelativePath { get; set; }
public string GetFullPath(string rootModelDirectory)
{
return Path.Combine(rootModelDirectory, RelativePath);
}
public string? GetPreviewImageFullPath(string rootModelDirectory)
{
return PreviewImageRelativePath == null ? null
: Path.Combine(rootModelDirectory, PreviewImageRelativePath);
}
public static readonly HashSet<string> SupportedCheckpointExtensions =
new() { ".safetensors", ".pt", ".ckpt", ".pth", ".bin" };
public static readonly HashSet<string> SupportedImageExtensions =
new() { ".png", ".jpg", ".jpeg" };
public static readonly HashSet<string> SupportedMetadataExtensions = new() { ".json" };
}

5
StabilityMatrix.Core/Models/FileInterfaces/FilePath.cs

@ -58,6 +58,11 @@ public class FilePath : FileSystemPath, IPathObject
{
}
public FilePath(FileInfo fileInfo) : base(fileInfo.FullName)
{
_info = fileInfo;
}
public FilePath(FileSystemPath path) : base(path)
{
}

22
StabilityMatrix.Core/Services/IModelIndexService.cs

@ -0,0 +1,22 @@
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
namespace StabilityMatrix.Core.Services;
public interface IModelIndexService
{
/// <summary>
/// Refreshes the local model file index.
/// </summary>
Task RefreshIndex();
/// <summary>
/// Get all models of the specified type from the existing index.
/// </summary>
Task<IReadOnlyList<LocalModelFile>> GetModelsOfType(SharedFolderType type);
/// <summary>
/// Starts a background task to refresh the local model file index.
/// </summary>
void BackgroundRefreshIndex();
}

135
StabilityMatrix.Core/Services/ModelIndexService.cs

@ -0,0 +1,135 @@
using System.Diagnostics;
using Microsoft.Extensions.Logging;
using StabilityMatrix.Core.Database;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.Database;
using StabilityMatrix.Core.Models.FileInterfaces;
namespace StabilityMatrix.Core.Services;
public class ModelIndexService : IModelIndexService
{
private readonly ILogger<ModelIndexService> logger;
private readonly ILiteDbContext liteDbContext;
private readonly ISettingsManager settingsManager;
public ModelIndexService(
ILogger<ModelIndexService> logger,
ILiteDbContext liteDbContext,
ISettingsManager settingsManager
)
{
this.logger = logger;
this.liteDbContext = liteDbContext;
this.settingsManager = settingsManager;
}
/// <summary>
/// Deletes all entries in the local model file index.
/// </summary>
private async Task ClearIndex()
{
await liteDbContext.LocalModelFiles.DeleteAllAsync().ConfigureAwait(false);
}
/// <inheritdoc />
public async Task<IReadOnlyList<LocalModelFile>> GetModelsOfType(SharedFolderType type)
{
return await liteDbContext.LocalModelFiles
.Query()
.Where(m => m.SharedFolderType == type)
.ToArrayAsync().ConfigureAwait(false);
}
/// <inheritdoc />
public async Task RefreshIndex()
{
var modelsDir = new DirectoryPath(settingsManager.ModelsDirectory);
// Start
var stopwatch = Stopwatch.StartNew();
logger.LogInformation("Refreshing model index...");
using var db
= await liteDbContext.Database.BeginTransactionAsync().ConfigureAwait(false);
var localModelFiles = db.GetCollection<LocalModelFile>("LocalModelFiles")!;
await localModelFiles.DeleteAllAsync().ConfigureAwait(false);
// Record start of actual indexing
var indexStart = stopwatch.Elapsed;
var added = 0;
foreach (
var file in modelsDir.Info
.EnumerateFiles("*.*", SearchOption.AllDirectories)
.Select(info => new FilePath(info))
)
{
// Skip if not supported extension
if (!LocalModelFile.SupportedCheckpointExtensions.Contains(file.Info.Extension))
{
continue;
}
var relativePath = Path.GetRelativePath(modelsDir, file);
// Get shared folder name
var sharedFolderName = relativePath.Split(Path.DirectorySeparatorChar,
StringSplitOptions.RemoveEmptyEntries)[0];
// Convert to enum
var sharedFolderType = Enum.Parse<SharedFolderType>(sharedFolderName, true);
var localModel = new LocalModelFile
{
RelativePath = relativePath,
SharedFolderType = sharedFolderType,
};
// Try to find a connected model info
var jsonPath = file.Directory!.JoinFile(
new FilePath(file.NameWithoutExtension, ".cm-info.json"));
if (jsonPath.Exists)
{
var connectedModelInfo = ConnectedModelInfo.FromJson(
await jsonPath.ReadAllTextAsync().ConfigureAwait(false));
localModel.ConnectedModelInfo = connectedModelInfo;
}
// Try to find a preview image
var previewImagePath = LocalModelFile.SupportedImageExtensions
.Select(ext => file.Directory!.JoinFile($"{file.NameWithoutExtension}.preview{ext}")
).FirstOrDefault(path => path.Exists);
if (previewImagePath != null)
{
localModel.PreviewImageRelativePath = Path.GetRelativePath(modelsDir, previewImagePath);
}
// Insert into database
await localModelFiles.InsertAsync(localModel).ConfigureAwait(false);
added++;
}
// Record end of actual indexing
var indexEnd = stopwatch.Elapsed;
await db.CommitAsync().ConfigureAwait(false);
// End
stopwatch.Stop();
var indexDuration = indexEnd - indexStart;
var dbDuration = stopwatch.Elapsed - indexDuration;
logger.LogInformation("Model index refreshed with {Entries} entries, took {IndexDuration:F1}ms ({DbDuration:F1}ms db)",
added, indexDuration.TotalMilliseconds, dbDuration.TotalMilliseconds);
}
/// <inheritdoc />
public void BackgroundRefreshIndex() { }
}
Loading…
Cancel
Save