|
|
|
@ -6,6 +6,10 @@ def load_torch_file(ckpt, safe_load=False):
|
|
|
|
|
import safetensors.torch |
|
|
|
|
sd = safetensors.torch.load_file(ckpt, device="cpu") |
|
|
|
|
else: |
|
|
|
|
if safe_load: |
|
|
|
|
if not 'weights_only' in torch.load.__code__.co_varnames: |
|
|
|
|
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") |
|
|
|
|
safe_load = False |
|
|
|
|
if safe_load: |
|
|
|
|
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) |
|
|
|
|
else: |
|
|
|
|