You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
6.7 KiB
205 lines
6.7 KiB
using System; |
|
using System.Collections.Concurrent; |
|
using System.Collections.Generic; |
|
using System.IO; |
|
using System.Linq; |
|
using System.Reactive.Linq; |
|
using System.Text.Json; |
|
using System.Text.Json.Serialization; |
|
using System.Threading.Tasks; |
|
using Avalonia; |
|
using Avalonia.Controls; |
|
using Avalonia.Controls.Notifications; |
|
using Avalonia.Data; |
|
using Avalonia.Input; |
|
using Avalonia.Markup.Xaml.MarkupExtensions; |
|
using Avalonia.Media; |
|
using Avalonia.Threading; |
|
using CommunityToolkit.Mvvm.ComponentModel; |
|
using CommunityToolkit.Mvvm.Input; |
|
using DynamicData; |
|
using DynamicData.Binding; |
|
using FluentAvalonia.UI.Controls; |
|
using StabilityMatrix.Avalonia.Languages; |
|
using StabilityMatrix.Avalonia.Models.HuggingFace; |
|
using StabilityMatrix.Avalonia.Services; |
|
using StabilityMatrix.Avalonia.ViewModels.Base; |
|
using StabilityMatrix.Avalonia.ViewModels.HuggingFacePage; |
|
using StabilityMatrix.Core.Attributes; |
|
using StabilityMatrix.Core.Extensions; |
|
using StabilityMatrix.Core.Models; |
|
using StabilityMatrix.Core.Models.FileInterfaces; |
|
using StabilityMatrix.Core.Models.Progress; |
|
using StabilityMatrix.Core.Services; |
|
|
|
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; |
|
|
|
[View(typeof(Views.HuggingFacePage))]
[Singleton]
public partial class HuggingFacePageViewModel : TabViewModelBase
{
    private readonly ITrackedDownloadService trackedDownloadService;
    private readonly ISettingsManager settingsManager;
    private readonly INotificationService notificationService;

    /// <summary>
    /// All Hugging Face items shown on the page, keyed by repository path + model name.
    /// Populated once from the bundled hf-packages.json in <see cref="OnLoaded"/>.
    /// </summary>
    public SourceCache<HuggingfaceItem, string> ItemsCache { get; } =
        new(i => i.RepositoryPath + i.ModelName);

    /// <summary>
    /// Items from <see cref="ItemsCache"/> grouped by model category and sorted by title.
    /// </summary>
    public IObservableCollection<CategoryViewModel> Categories { get; set; } =
        new ObservableCollectionExtended<CategoryViewModel>();

    /// <summary>
    /// Status text describing the aggregate progress in <see cref="TotalProgress"/>.
    /// </summary>
    public string DownloadPercentText =>
        IsComplete(TotalProgress)
            ? "Download Complete"
            : $"Downloading {TotalProgress.Percentage:0}%";

    /// <summary>
    /// Total number of selected items across all categories.
    /// </summary>
    [ObservableProperty]
    private int numSelected;

    /// <summary>
    /// Latest progress report per tracked download, keyed by download id.
    /// Written from download ProgressUpdate handlers (a concurrent dictionary because those
    /// may fire off the UI thread) and read by <see cref="progressTimer"/>.
    /// </summary>
    private readonly ConcurrentDictionary<Guid, ProgressReport> progressReports = new();

    /// <summary>
    /// Aggregate progress across all tracked downloads, recomputed on every timer tick.
    /// </summary>
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(DownloadPercentText))]
    private ProgressReport totalProgress;

    /// <summary>
    /// Polls <see cref="progressReports"/> every 100 ms while an import batch is running.
    /// </summary>
    private readonly DispatcherTimer progressTimer = new() { Interval = TimeSpan.FromMilliseconds(100) };

    public HuggingFacePageViewModel(
        ITrackedDownloadService trackedDownloadService,
        ISettingsManager settingsManager,
        INotificationService notificationService
    )
    {
        this.trackedDownloadService = trackedDownloadService;
        this.settingsManager = settingsManager;
        this.notificationService = notificationService;

        // Build one CategoryViewModel per model category, sort by title, bind into
        // Categories, and recompute NumSelected whenever any category's properties change.
        ItemsCache
            .Connect()
            .DeferUntilLoaded()
            .Group(i => i.ModelCategory)
            .Transform(
                g =>
                    new CategoryViewModel(
                        g.Cache.Items,
                        Design.IsDesignMode ? string.Empty : settingsManager.ModelsDirectory
                    )
                    {
                        Title = g.Key.GetDescription() ?? g.Key.ToString()
                    }
            )
            .SortBy(vm => vm.Title)
            .Bind(Categories)
            .WhenAnyPropertyChanged()
            .Subscribe(_ => NumSelected = Categories.Sum(c => c.NumSelected));

        // On each tick, sum the latest Current/Total of every tracked download into one report.
        progressTimer.Tick += (_, _) =>
        {
            var currentSum = 0ul;
            var totalSum = 0ul;
            foreach (var progress in progressReports.Values)
            {
                currentSum += progress.Current ?? 0;
                totalSum += progress.Total ?? 0;
            }

            TotalProgress = new ProgressReport(current: currentSum, total: totalSum);
        };
    }

    /// <summary>
    /// Loads the bundled package list into <see cref="ItemsCache"/> the first time the page
    /// is shown. Later calls are no-ops once the cache has items.
    /// </summary>
    /// <exception cref="InvalidOperationException">The JSON deserialized to null.</exception>
    public override void OnLoaded()
    {
        if (ItemsCache.Count > 0)
        {
            return;
        }

        using var reader = new StreamReader(Assets.HfPackagesJson.Open());
        var packages =
            JsonSerializer.Deserialize<IReadOnlyList<HuggingfaceItem>>(
                reader.ReadToEnd(),
                new JsonSerializerOptions { Converters = { new JsonStringEnumConverter() } }
            ) ?? throw new InvalidOperationException("Failed to read hf-packages.json");

        ItemsCache.EditDiff(packages, (a, b) => a.RepositoryPath == b.RepositoryPath);
    }

    /// <summary>
    /// Unchecks every category.
    /// </summary>
    public void ClearSelection()
    {
        foreach (var category in Categories)
        {
            // Going through `true` first guarantees a change to `false` is observed even if
            // IsChecked is already false. NOTE(review): presumably so the unchecked state is
            // pushed down to the category's items — confirm against CategoryViewModel.
            category.IsChecked = true;
            category.IsChecked = false;
        }
    }

    /// <summary>
    /// Checks every category.
    /// </summary>
    public void SelectAll()
    {
        foreach (var category in Categories)
        {
            category.IsChecked = true;
        }
    }

    /// <summary>
    /// Starts a tracked download for every file of every selected item into
    /// <c>{models dir}/{shared folder}/{subfolder}/{file}</c>, then starts polling
    /// aggregate progress. Does nothing when no item is selected.
    /// </summary>
    [RelayCommand]
    private async Task ImportSelected()
    {
        var selected = Categories.SelectMany(c => c.Items).Where(i => i.IsSelected).ToArray();

        // With nothing to download, starting the timer would make it tick indefinitely
        // (or re-announce completion based on a previous batch's reports).
        if (selected.Length == 0)
        {
            return;
        }

        // Reports from an earlier, already-finished batch would otherwise be summed into this
        // batch's totals. While a batch is still running the timer is enabled and its
        // reports must be kept.
        if (!progressTimer.IsEnabled)
        {
            progressReports.Clear();
        }

        string modelsRoot = Design.IsDesignMode ? string.Empty : settingsManager.ModelsDirectory;

        foreach (var viewModel in selected)
        {
            var item = viewModel.Item;
            var sharedFolderType = item.ModelCategory.ConvertTo<SharedFolderType>();

            foreach (var file in item.Files)
            {
                var url = $"https://huggingface.co/{item.RepositoryPath}/resolve/main/{file}?download=true";
                var downloadPath = new FilePath(
                    Path.Combine(
                        modelsRoot,
                        sharedFolderType.ToString(),
                        item.Subfolder ?? string.Empty,
                        file
                    )
                );
                Directory.CreateDirectory(downloadPath.Directory);

                var download = trackedDownloadService.NewDownload(url, downloadPath);
                download.ProgressUpdate += DownloadOnProgressUpdate;
                download.Start();

                // Small randomized gap between starts. NOTE(review): presumably to avoid
                // firing all requests at once — confirm intent.
                await Task.Delay(Random.Shared.Next(50, 100));
            }
        }

        progressTimer.Start();
    }

    /// <summary>
    /// Records the latest progress report of a <see cref="TrackedDownload"/>; reports from
    /// any other sender are ignored.
    /// </summary>
    private void DownloadOnProgressUpdate(object? sender, ProgressReport e)
    {
        if (sender is not TrackedDownload trackedDownload)
        {
            return;
        }

        progressReports[trackedDownload.Id] = e;
    }

    /// <summary>
    /// When aggregate progress reaches 100%, notifies the user, stops polling and schedules
    /// the progress display to be reset.
    /// </summary>
    partial void OnTotalProgressChanged(ProgressReport value)
    {
        if (!IsComplete(value))
        {
            return;
        }

        notificationService.Show(
            "Download complete",
            "All selected models have been downloaded.",
            NotificationType.Success
        );
        progressTimer.Stop();
        DelayedClearProgress(TimeSpan.FromSeconds(1.5));
    }

    /// <summary>
    /// Resets <see cref="TotalProgress"/> after <paramref name="delay"/>.
    /// </summary>
    private void DelayedClearProgress(TimeSpan delay)
    {
        // The ContinueWith callback runs on a thread-pool thread, but TotalProgress is bound
        // to the UI, so the assignment is posted back to the UI thread.
        Task.Delay(delay)
            .ContinueWith(_ => Dispatcher.UIThread.Post(() => TotalProgress = new ProgressReport(0, 0)));
    }

    /// <summary>
    /// Whether <paramref name="report"/> is at 100%, compared with a small tolerance.
    /// </summary>
    private static bool IsComplete(ProgressReport report) =>
        Math.Abs(report.Percentage - 100f) < 0.001f;

    public override string Header => Resources.Label_HuggingFace;
}
|
|
|