From 735ac4cf81862b21902b312930ebfc92eef63357 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jun 2023 10:11:33 -0400 Subject: [PATCH] Remove pytorch_lightning dependency. --- comfy/checkpoint_pickle.py | 13 +++++++++++++ comfy/utils.py | 3 ++- requirements.txt | 1 - 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 comfy/checkpoint_pickle.py diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py new file mode 100644 index 00000000..206551d3 --- /dev/null +++ b/comfy/checkpoint_pickle.py @@ -0,0 +1,13 @@ +import pickle + +load = pickle.load + +class Empty: + pass + +class Unpickler(pickle.Unpickler): + def find_class(self, module, name): + #TODO: safe unpickle + if module.startswith("pytorch_lightning"): + return Empty + return super().find_class(module, name) diff --git a/comfy/utils.py b/comfy/utils.py index 585ebda5..401eb803 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,6 +1,7 @@ import torch import math import struct +import comfy.checkpoint_pickle def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False): if safe_load: pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) else: - pl_sd = torch.load(ckpt, map_location="cpu") + pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: diff --git a/requirements.txt b/requirements.txt index c551f683..d632edf7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ torchsde einops transformers>=4.25.1 safetensors>=0.3.0 -pytorch_lightning aiohttp accelerate pyyaml