You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
206 lines
6.9 KiB
206 lines
6.9 KiB
12 months ago
|
using System;
|
||
|
using System.Collections.Concurrent;
|
||
|
using System.Collections.Generic;
|
||
|
using System.IO;
|
||
|
using System.Linq;
|
||
12 months ago
|
using System.Reactive.Linq;
|
||
12 months ago
|
using System.Text.Json;
|
||
|
using System.Text.Json.Serialization;
|
||
|
using System.Threading.Tasks;
|
||
12 months ago
|
using Avalonia;
|
||
12 months ago
|
using Avalonia.Controls.Notifications;
|
||
12 months ago
|
using Avalonia.Data;
|
||
|
using Avalonia.Input;
|
||
|
using Avalonia.Markup.Xaml.MarkupExtensions;
|
||
|
using Avalonia.Media;
|
||
12 months ago
|
using Avalonia.Threading;
|
||
|
using CommunityToolkit.Mvvm.ComponentModel;
|
||
12 months ago
|
using CommunityToolkit.Mvvm.Input;
|
||
12 months ago
|
using DynamicData;
|
||
|
using DynamicData.Binding;
|
||
|
using FluentAvalonia.UI.Controls;
|
||
12 months ago
|
using StabilityMatrix.Avalonia.Languages;
|
||
12 months ago
|
using StabilityMatrix.Avalonia.Models.HuggingFace;
|
||
|
using StabilityMatrix.Avalonia.Services;
|
||
|
using StabilityMatrix.Avalonia.ViewModels.Base;
|
||
|
using StabilityMatrix.Avalonia.ViewModels.HuggingFacePage;
|
||
|
using StabilityMatrix.Core.Attributes;
|
||
|
using StabilityMatrix.Core.Extensions;
|
||
|
using StabilityMatrix.Core.Models;
|
||
|
using StabilityMatrix.Core.Models.FileInterfaces;
|
||
|
using StabilityMatrix.Core.Models.Progress;
|
||
|
using StabilityMatrix.Core.Services;
|
||
|
|
||
12 months ago
|
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
|
||
12 months ago
|
|
||
|
/// <summary>
/// View model for the Hugging Face tab. Loads the bundled package list into
/// <see cref="ItemsCache"/>, groups it into <see cref="Categories"/>, and starts tracked
/// downloads for the selected items while reporting their combined progress.
/// </summary>
[View(typeof(Views.HuggingFacePage))]
[Singleton]
public partial class HuggingFacePageViewModel : TabViewModelBase
{
    /// <summary>Tolerance used when checking whether a percentage has reached 100.</summary>
    private const float CompletionTolerance = 0.001f;

    private readonly ITrackedDownloadService trackedDownloadService;
    private readonly ISettingsManager settingsManager;
    private readonly INotificationService notificationService;

    /// <summary>
    /// Latest progress report of each download started by this view model, keyed by download id.
    /// Written from download progress callbacks and read by the progress timer.
    /// </summary>
    private readonly ConcurrentDictionary<Guid, ProgressReport> progressReports = new();

    /// <summary>Timer that recomputes <see cref="TotalProgress"/> every 100 ms while running.</summary>
    private readonly DispatcherTimer progressTimer =
        new() { Interval = TimeSpan.FromMilliseconds(100) };

    /// <summary>All known items, keyed by repository path concatenated with model name.</summary>
    public SourceCache<HuggingfaceItem, string> ItemsCache { get; } =
        new(i => i.RepositoryPath + i.ModelName);

    /// <summary>One view model per model category, sorted by title.</summary>
    public IObservableCollection<CategoryViewModel> Categories { get; set; } =
        new ObservableCollectionExtended<CategoryViewModel>();

    /// <summary>Status text derived from <see cref="TotalProgress"/>.</summary>
    public string DownloadPercentText =>
        IsComplete(TotalProgress)
            ? "Download Complete"
            : $"Downloading {TotalProgress.Percentage:0}%";

    /// <summary>Sum of the selected-item counts of all categories.</summary>
    [ObservableProperty]
    private int numSelected;

    /// <summary>Combined progress of all downloads currently held in the progress map.</summary>
    [ObservableProperty]
    [NotifyPropertyChangedFor(nameof(DownloadPercentText))]
    private ProgressReport totalProgress;

    /// <summary>
    /// Wires up the category pipeline, the progress timer, the bundled package list and the
    /// import command-bar button.
    /// </summary>
    /// <param name="trackedDownloadService">Service used to create tracked downloads.</param>
    /// <param name="settingsManager">Provides the models directory used as download root.</param>
    /// <param name="notificationService">Used to announce that all downloads finished.</param>
    /// <exception cref="InvalidOperationException">
    /// The bundled package list deserialized to <c>null</c>.
    /// </exception>
    public HuggingFacePageViewModel(
        ITrackedDownloadService trackedDownloadService,
        ISettingsManager settingsManager,
        INotificationService notificationService
    )
    {
        this.trackedDownloadService = trackedDownloadService;
        this.settingsManager = settingsManager;
        this.notificationService = notificationService;

        // Group items by category, expose one CategoryViewModel per group, and keep
        // NumSelected in sync whenever any category view model raises a property change.
        ItemsCache
            .Connect()
            .DeferUntilLoaded()
            .Group(i => i.ModelCategory)
            .Transform(
                g =>
                    new CategoryViewModel(g.Cache.Items)
                    {
                        Title = g.Key.GetDescription() ?? g.Key.ToString()
                    }
            )
            .SortBy(vm => vm.Title)
            .Bind(Categories)
            .WhenAnyPropertyChanged()
            .Subscribe(_ => NumSelected = Categories.Sum(c => c.NumSelected));

        progressTimer.Tick += (_, _) => UpdateTotalProgress();

        ItemsCache.EditDiff(LoadPackages(), (a, b) => a.RepositoryPath == b.RepositoryPath);

        PrimaryCommands.Add(CreateImportButton());
    }

    /// <summary>Display name of this tab.</summary>
    public override string Header => "Hugging Face";

    /// <summary>Unchecks every category.</summary>
    public void ClearSelection()
    {
        foreach (var category in Categories)
        {
            // NOTE(review): toggling to true first presumably forces a change notification
            // even when the category is already unchecked; confirm against CategoryViewModel.
            category.IsChecked = true;
            category.IsChecked = false;
        }
    }

    /// <summary>Checks every category.</summary>
    public void SelectAll()
    {
        foreach (var category in Categories)
        {
            category.IsChecked = true;
        }
    }

    /// <summary>
    /// Starts a tracked download for every file of every selected item, then starts the
    /// progress timer. Does nothing when no item is selected.
    /// </summary>
    [RelayCommand]
    private async Task ImportSelected()
    {
        var selected = Categories.SelectMany(c => c.Items).Where(i => i.IsSelected).ToArray();

        // Without this guard the timer would start with nothing to track and never be
        // stopped, because it is only stopped once the aggregate reaches 100%.
        if (selected.Length == 0)
        {
            return;
        }

        foreach (var viewModel in selected)
        {
            foreach (var file in viewModel.Item.Files)
            {
                var url =
                    $"https://huggingface.co/{viewModel.Item.RepositoryPath}/resolve/main/{file}?download=true";
                var sharedFolderType = viewModel.Item.ModelCategory.ConvertTo<SharedFolderType>();

                // <models dir>/<shared folder name>/<optional subfolder>/<file>
                var downloadPath = new FilePath(
                    Path.Combine(
                        settingsManager.ModelsDirectory,
                        sharedFolderType.GetDescription() ?? sharedFolderType.ToString(),
                        viewModel.Item.Subfolder ?? string.Empty,
                        file
                    )
                );
                Directory.CreateDirectory(downloadPath.Directory);

                var download = trackedDownloadService.NewDownload(url, downloadPath);
                download.ProgressUpdate += DownloadOnProgressUpdate;
                download.Start();

                // NOTE(review): the short random pause presumably staggers download starts.
                await Task.Delay(Random.Shared.Next(50, 100));
            }
        }

        progressTimer.Start();
    }

    /// <summary>Stores the latest progress report of the download that raised the event.</summary>
    /// <param name="sender">Expected to be the <see cref="TrackedDownload"/>; ignored otherwise.</param>
    /// <param name="e">The reported progress.</param>
    private void DownloadOnProgressUpdate(object? sender, ProgressReport e)
    {
        if (sender is not TrackedDownload trackedDownload)
        {
            return;
        }

        progressReports[trackedDownload.Id] = e;
    }

    /// <summary>
    /// When the combined progress reaches 100%, shows a success notification, stops the
    /// timer, forgets the finished reports and schedules the progress display to reset.
    /// </summary>
    /// <param name="value">The new combined progress.</param>
    partial void OnTotalProgressChanged(ProgressReport value)
    {
        if (!IsComplete(value))
        {
            return;
        }

        notificationService.Show(
            "Download complete",
            "All selected models have been downloaded.",
            NotificationType.Success
        );
        progressTimer.Stop();

        // Drop the finished reports so a later import batch is not skewed by
        // already-completed downloads still counted in the sums.
        progressReports.Clear();

        DelayedClearProgress(TimeSpan.FromSeconds(1.5));
    }

    /// <summary>Resets <see cref="TotalProgress"/> to 0/0 after <paramref name="delay"/>.</summary>
    /// <param name="delay">How long to keep the completed state before resetting.</param>
    private void DelayedClearProgress(TimeSpan delay)
    {
        // The continuation is not guaranteed to run on the UI thread, so the observable
        // property is updated through the dispatcher.
        Task.Delay(delay)
            .ContinueWith(
                _ => Dispatcher.UIThread.Post(() => TotalProgress = new ProgressReport(0, 0))
            );
    }

    /// <summary>Sums current and total values of all tracked reports into <see cref="TotalProgress"/>.</summary>
    private void UpdateTotalProgress()
    {
        var currentSum = 0ul;
        var totalSum = 0ul;
        foreach (var progress in progressReports.Values)
        {
            // Reports with a missing value contribute zero to the respective sum.
            currentSum += progress.Current ?? 0;
            totalSum += progress.Total ?? 0;
        }

        TotalProgress = new ProgressReport(current: currentSum, total: totalSum);
    }

    /// <summary>Returns whether <paramref name="report"/> is within tolerance of 100%.</summary>
    private static bool IsComplete(ProgressReport report) =>
        Math.Abs(report.Percentage - 100f) < CompletionTolerance;

    /// <summary>Reads and deserializes the bundled package list, parsing enums from strings.</summary>
    /// <exception cref="InvalidOperationException">The content deserialized to <c>null</c>.</exception>
    private static IReadOnlyList<HuggingfaceItem> LoadPackages()
    {
        using var reader = new StreamReader(Assets.HfPackagesJson.Open());
        var options = new JsonSerializerOptions { Converters = { new JsonStringEnumConverter() } };

        return JsonSerializer.Deserialize<IReadOnlyList<HuggingfaceItem>>(
                reader.ReadToEnd(),
                options
            ) ?? throw new InvalidOperationException("Failed to read hf-packages.json");
    }

    /// <summary>
    /// Creates the command-bar button that runs the import command and is enabled only
    /// while at least one item is selected.
    /// </summary>
    private CommandBarButton CreateImportButton()
    {
        var button = new CommandBarButton
        {
            Command = ImportSelectedCommand,
            Foreground = new SolidColorBrush(Colors.Lime),
            IconSource = new SymbolIconSource { Symbol = Symbol.Download }
        };

        // Drive IsEnabled from NumSelected changes. Deriving it from the button's own
        // IsEnabled observable would never re-evaluate when the selection changes.
        var hasSelection = this.WhenValueChanged(vm => vm.NumSelected).Select(count => count > 0);
        button.Bind(InputElement.IsEnabledProperty, hasSelection);

        return button;
    }
}
|