Multi-Platform Package Manager for Stable Diffusion
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

201 lines
6.5 KiB

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using Avalonia;
using Avalonia.Controls.Notifications;
using Avalonia.Data;
using Avalonia.Input;
using Avalonia.Markup.Xaml.MarkupExtensions;
using Avalonia.Media;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using DynamicData;
using DynamicData.Binding;
using FluentAvalonia.UI.Controls;
using StabilityMatrix.Avalonia.Languages;
using StabilityMatrix.Avalonia.Models.HuggingFace;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Avalonia.ViewModels.HuggingFacePage;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Models;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
using StabilityMatrix.Core.Services;
namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser;
[View(typeof(Views.HuggingFacePage))]
[Singleton]
public partial class HuggingFacePageViewModel : TabViewModelBase
{
private readonly ITrackedDownloadService trackedDownloadService;
private readonly ISettingsManager settingsManager;
private readonly INotificationService notificationService;
public SourceCache<HuggingfaceItem, string> ItemsCache { get; } =
new(i => i.RepositoryPath + i.ModelName);
public IObservableCollection<CategoryViewModel> Categories { get; set; } =
new ObservableCollectionExtended<CategoryViewModel>();
public string DownloadPercentText =>
Math.Abs(TotalProgress.Percentage - 100f) < 0.001f
? "Download Complete"
: $"Downloading {TotalProgress.Percentage:0}%";
[ObservableProperty]
private int numSelected;
private ConcurrentDictionary<Guid, ProgressReport> progressReports = new();
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(DownloadPercentText))]
private ProgressReport totalProgress;
private readonly DispatcherTimer progressTimer = new() { Interval = TimeSpan.FromMilliseconds(100) };
public HuggingFacePageViewModel(
ITrackedDownloadService trackedDownloadService,
ISettingsManager settingsManager,
INotificationService notificationService
)
{
this.trackedDownloadService = trackedDownloadService;
this.settingsManager = settingsManager;
this.notificationService = notificationService;
ItemsCache
.Connect()
.DeferUntilLoaded()
.Group(i => i.ModelCategory)
.Transform(
g =>
new CategoryViewModel(g.Cache.Items)
{
Title = g.Key.GetDescription() ?? g.Key.ToString()
}
)
.SortBy(vm => vm.Title)
.Bind(Categories)
.WhenAnyPropertyChanged()
.Subscribe(_ => NumSelected = Categories.Sum(c => c.NumSelected));
progressTimer.Tick += (_, _) =>
{
var currentSum = 0ul;
var totalSum = 0ul;
foreach (var progress in progressReports.Values)
{
currentSum += progress.Current ?? 0;
totalSum += progress.Total ?? 0;
}
TotalProgress = new ProgressReport(current: currentSum, total: totalSum);
};
}
public override void OnLoaded()
{
if (ItemsCache.Count > 0)
return;
using var reader = new StreamReader(Assets.HfPackagesJson.Open());
var packages =
JsonSerializer.Deserialize<IReadOnlyList<HuggingfaceItem>>(
reader.ReadToEnd(),
new JsonSerializerOptions { Converters = { new JsonStringEnumConverter() } }
) ?? throw new InvalidOperationException("Failed to read hf-packages.json");
ItemsCache.EditDiff(packages, (a, b) => a.RepositoryPath == b.RepositoryPath);
}
public void ClearSelection()
{
foreach (var category in Categories)
{
category.IsChecked = true;
category.IsChecked = false;
}
}
public void SelectAll()
{
foreach (var category in Categories)
{
category.IsChecked = true;
}
}
[RelayCommand]
private async Task ImportSelected()
{
var selected = Categories.SelectMany(c => c.Items).Where(i => i.IsSelected).ToArray();
foreach (var viewModel in selected)
{
foreach (var file in viewModel.Item.Files)
{
var url =
$"https://huggingface.co/{viewModel.Item.RepositoryPath}/resolve/main/{file}?download=true";
var sharedFolderType = viewModel.Item.ModelCategory.ConvertTo<SharedFolderType>();
var downloadPath = new FilePath(
Path.Combine(
settingsManager.ModelsDirectory,
sharedFolderType.ToString(),
viewModel.Item.Subfolder ?? string.Empty,
file
)
);
Directory.CreateDirectory(downloadPath.Directory);
var download = trackedDownloadService.NewDownload(url, downloadPath);
download.ProgressUpdate += DownloadOnProgressUpdate;
download.Start();
await Task.Delay(Random.Shared.Next(50, 100));
}
}
progressTimer.Start();
}
private void DownloadOnProgressUpdate(object? sender, ProgressReport e)
{
if (sender is not TrackedDownload trackedDownload)
return;
progressReports[trackedDownload.Id] = e;
}
partial void OnTotalProgressChanged(ProgressReport value)
{
if (Math.Abs(value.Percentage - 100) < 0.001f)
{
notificationService.Show(
"Download complete",
"All selected models have been downloaded.",
NotificationType.Success
);
progressTimer.Stop();
DelayedClearProgress(TimeSpan.FromSeconds(1.5));
}
}
private void DelayedClearProgress(TimeSpan delay)
{
Task.Delay(delay)
.ContinueWith(_ =>
{
TotalProgress = new ProgressReport(0, 0);
});
}
public override string Header => Resources.Label_HuggingFace;
}