From 65397ce601030c9f5934b0b277a019bdaa911a5f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 10 Mar 2024 11:37:08 -0400
Subject: [PATCH] Replace prints with logging and add --verbose argument.

---
 comfy/cli_args.py          | 10 ++++++++++
 comfy/clip_vision.py       |  3 ++-
 comfy/controlnet.py        | 18 +++++++++++------
 comfy/diffusers_convert.py |  3 ++-
 comfy/lora.py              |  3 ++-
 comfy/model_base.py        |  9 +++++----
 comfy/model_detection.py   |  3 ++-
 comfy/model_management.py  | 41 +++++++++++++++++++-------------------
 comfy/model_patcher.py     | 18 ++++++++++-------
 comfy/sd.py                | 33 +++++++++++++++---------------
 comfy/sd1_clip.py          |  9 ++++-----
 comfy/utils.py             |  5 +++--
 12 files changed, 90 insertions(+), 65 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index b4bbfbfa..757fc245 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -114,6 +114,9 @@ parser.add_argument("--disable-metadata", action="store_true", help="Disable sav
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
 
+parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
+
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:
@@ -124,3 +127,10 @@ if args.windows_standalone_build:
 
 if args.disable_auto_launch:
     args.auto_launch = False
+
+import logging
+logging_level = logging.WARNING
+if args.verbose:
+    logging_level = logging.DEBUG
+
+logging.basicConfig(format="%(message)s", level=logging_level)
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 8c77ee7a..acc86be8 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
 import os
 import torch
 import json
+import logging
 
 import comfy.ops
 import comfy.model_patcher
@@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     clip = ClipVisionModel(json_config)
     m, u = clip.load_sd(sd)
     if len(m) > 0:
-        print("missing clip vision:", m)
+        logging.warning("missing clip vision: {}".format(m))
     u = set(u)
     keys = list(sd.keys())
     for k in keys:
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 1352d141..a5e7b23f 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -1,6 +1,7 @@
 import torch
 import math
 import os
+import logging
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
@@ -367,7 +368,7 @@ def load_controlnet(ckpt_path, model=None):
 
         leftover_keys = controlnet_data.keys()
         if len(leftover_keys) > 0:
-            print("leftover keys:", leftover_keys)
+            logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
 
     pth_key = 'control_model.zero_convs.0.0.weight'
@@ -382,7 +383,7 @@ def load_controlnet(ckpt_path, model=None):
     else:
         net = load_t2i_adapter(controlnet_data)
         if net is None:
-            print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
+            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
         return net
 
     if controlnet_config is None:
@@ -417,7 +418,7 @@ def load_controlnet(ckpt_path, model=None):
                 cd = controlnet_data[x]
                 cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
         else:
-            print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
+            logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
 
         class WeightsLoader(torch.nn.Module):
             pass
@@ -426,7 +427,12 @@ def load_controlnet(ckpt_path, model=None):
         missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
     else:
         missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
-    print(missing, unexpected)
+
+    if len(missing) > 0:
+        logging.warning("missing controlnet keys: {}".format(missing))
+
+    if len(unexpected) > 0:
+        logging.info("unexpected controlnet keys: {}".format(unexpected))
 
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
@@ -536,9 +542,9 @@ def load_t2i_adapter(t2i_data):
 
     missing, unexpected = model_ad.load_state_dict(t2i_data)
     if len(missing) > 0:
-        print("t2i missing", missing)
+        logging.warning("t2i missing {}".format(missing))
 
     if len(unexpected) > 0:
-        print("t2i unexpected", unexpected)
+        logging.info("t2i unexpected {}".format(unexpected))
 
     return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py
index eb561933..18398cb3 100644
--- a/comfy/diffusers_convert.py
+++ b/comfy/diffusers_convert.py
@@ -1,5 +1,6 @@
 import re
 import torch
+import logging
 
 # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
 
@@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
-                print(f"Reshaping {k} for SD format")
+                logging.info(f"Reshaping {k} for SD format")
                 new_state_dict[k] = reshape_weight_for_sd(v)
     return new_state_dict
diff --git a/comfy/lora.py b/comfy/lora.py
index 21b9897a..637380d5 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -1,4 +1,5 @@
 import comfy.utils
+import logging
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -156,7 +157,7 @@ def load_lora(lora, to_load):
 
     for x in lora.keys():
         if x not in loaded_keys:
-            print("lora key not loaded", x)
+            logging.warning("lora key not loaded: {}".format(x))
     return patch_dict
 
 def model_lora_keys_clip(model, key_map={}):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index a9de1366..a2514ca5 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,4 +1,5 @@
 import torch
+import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from comfy.ldm.cascade.stage_c import StageC
 from comfy.ldm.cascade.stage_b import StageB
@@ -66,8 +67,8 @@ class BaseModel(torch.nn.Module):
         if self.adm_channels is None:
             self.adm_channels = 0
         self.inpaint_model = False
-        print("model_type", model_type.name)
-        print("adm", self.adm_channels)
+        logging.warning("model_type {}".format(model_type.name))
+        logging.info("adm {}".format(self.adm_channels))
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
@@ -183,10 +184,10 @@ class BaseModel(torch.nn.Module):
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
-            print("unet missing:", m)
+            logging.warning("unet missing: {}".format(m))
 
         if len(u) > 0:
-            print("unet unexpected:", u)
+            logging.warning("unet unexpected: {}".format(u))
         del to_load
         return self
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 07ee8570..b7c3be30 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -1,5 +1,6 @@
 import comfy.supported_models
 import comfy.supported_models_base
+import logging
 
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
@@ -186,7 +187,7 @@ def model_config_from_unet_config(unet_config):
         if model_config.matches(unet_config):
             return model_config(unet_config)
 
-    print("no match", unet_config)
+    logging.error("no match {}".format(unet_config))
     return None
 
 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9f924883..dd262e26 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,4 +1,5 @@
 import psutil
+import logging
 from enum import Enum
 from comfy.cli_args import args
 import comfy.utils
@@ -29,7 +30,7 @@ lowvram_available = True
 xpu_available = False
 
 if args.deterministic:
-    print("Using deterministic algorithms for pytorch")
+    logging.warning("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
 directml_enabled = False
@@ -41,7 +42,7 @@ if args.directml is not None:
         directml_device = torch_directml.device()
     else:
         directml_device = torch_directml.device(device_index)
-    print("Using directml with device:", torch_directml.device_name(device_index))
+    logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
@@ -117,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False):
 
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
-        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
         set_vram_to = VRAMState.LOW_VRAM
 
 try:
@@ -143,12 +144,10 @@ else:
     pass
 try:
     XFORMERS_VERSION = xformers.version.__version__
-    print("xformers version:", XFORMERS_VERSION)
+    logging.warning("xformers version: {}".format(XFORMERS_VERSION))
     if XFORMERS_VERSION.startswith("0.0.18"):
-        print()
-        print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-        print("Please downgrade or upgrade xformers to a different version.")
-        print()
+        logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
+        logging.warning("Please downgrade or upgrade xformers to a different version.\n")
         XFORMERS_ENABLED_VAE = False
 except:
     pass
@@ -213,11 +212,11 @@ elif args.highvram or args.gpu_only:
 FORCE_FP32 = False
 FORCE_FP16 = False
 if args.force_fp32:
-    print("Forcing FP32, if this improves things please report it.")
+    logging.warning("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 
 if args.force_fp16:
-    print("Forcing FP16.")
+    logging.warning("Forcing FP16.")
     FORCE_FP16 = True
 
 if lowvram_available:
@@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU:
 if cpu_state == CPUState.MPS:
     vram_state = VRAMState.SHARED
 
-print(f"Set vram state to: {vram_state.name}")
+logging.warning(f"Set vram state to: {vram_state.name}")
 
 DISABLE_SMART_MEMORY = args.disable_smart_memory
 
 if DISABLE_SMART_MEMORY:
-    print("Disabling smart memory management")
+    logging.warning("Disabling smart memory management")
 
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
@@ -254,11 +253,11 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    print("Device:", get_torch_device_name(get_torch_device()))
+    logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
-    print("Could not pick default device.")
+    logging.warning("Could not pick default device.")
 
-print("VAE dtype:", VAE_DTYPE)
+logging.warning("VAE dtype: {}".format(VAE_DTYPE))
 
 current_loaded_models = []
@@ -301,7 +300,7 @@ class LoadedModel:
                 raise e
 
             if lowvram_model_memory > 0:
-                print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+                logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
                 mem_counter = 0
                 for m in self.real_model.modules():
                     if hasattr(m, "comfy_cast_weights"):
@@ -314,7 +313,7 @@ class LoadedModel:
                     elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
                         m.to(self.device)
                         mem_counter += module_size(m)
-                        print("lowvram: loaded module regularly", m)
+                        logging.warning("lowvram: loaded module regularly {}".format(m))
 
             self.model_accelerated = True
@@ -348,7 +347,7 @@ def unload_model_clones(model):
             to_unload = [i] + to_unload
 
     for i in to_unload:
-        print("unload clone", i)
+        logging.warning("unload clone {}".format(i))
         current_loaded_models.pop(i).model_unload()
 
 def free_memory(memory_required, device, keep_loaded=[]):
@@ -390,7 +389,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
+                logging.warning(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
     if len(models_to_load) == 0:
@@ -400,7 +399,7 @@ def load_models_gpu(models, memory_required=0):
             free_memory(extra_mem, d, models_already_loaded)
         return
 
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") + logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4a5d42b0..5e578dff 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1,6 +1,7 @@ import torch import copy import inspect +import logging import comfy.utils import comfy.model_management @@ -187,7 +188,7 @@ class ModelPatcher: model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) + logging.warning("could not patch. key doesn't exist in model: {}".format(key)) continue weight = model_sd[key] @@ -236,7 +237,7 @@ class ModelPatcher: w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif patch_type == "lora": #lora/locon @@ -252,7 +253,7 @@ class ModelPatcher: try: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": w1 = v[0] w2 = v[1] @@ -291,7 +292,7 @@ class ModelPatcher: try: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": w1a = v[0] w1b = v[1] @@ -320,7 +321,7 @@ class ModelPatcher: try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": if v[4] is not None: alpha *= v[4] / v[0].shape[0] @@ -330,9 +331,12 @@ class ModelPatcher: b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) - weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + try: + weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + except Exception as e: + logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: - print("patch type not recognized", patch_type, key) + logging.warning("patch type not recognized {} {}".format(patch_type, key)) return weight diff --git a/comfy/sd.py b/comfy/sd.py index fd5d604e..3e4b9e47 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,5 +1,6 @@ import torch from enum import Enum +import logging from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine @@ -37,7 +38,7 @@ def load_model_weights(model, sd): w = sd.pop(x) del w if len(m) > 0: - print("missing", m) + logging.warning("missing {}".format(m)) return model def load_clip_weights(model, sd): @@ -81,7 +82,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): k1 = set(k1) for x in loaded: if (x not in 
-            print("NOT LOADED", x)
+            logging.warning("NOT LOADED {}".format(x))
 
     return (new_modelpatcher, new_clip)
 
@@ -225,10 +226,10 @@ class VAE:
 
             m, u = self.first_stage_model.load_state_dict(sd, strict=False)
             if len(m) > 0:
-                print("Missing VAE keys", m)
+                logging.warning("Missing VAE keys {}".format(m))
 
             if len(u) > 0:
-                print("Leftover VAE keys", u)
+                logging.info("Leftover VAE keys {}".format(u))
 
         if device is None:
             device = model_management.vae_device()
@@ -291,7 +292,7 @@ class VAE:
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
@@ -317,7 +318,7 @@ class VAE:
                 samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
 
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             samples = self.encode_tiled_(pixel_samples)
 
         return samples
@@ -393,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
-            print("clip missing:", m)
+            logging.warning("clip missing: {}".format(m))
 
         if len(u) > 0:
-            print("clip unexpected:", u)
+            logging.info("clip unexpected: {}".format(u))
     return clip
 
 def load_gligen(ckpt_path):
@@ -534,21 +535,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             clip = CLIP(clip_target, embedding_directory=embedding_directory)
             m, u = clip.load_sd(clip_sd, full_model=True)
             if len(m) > 0:
-                print("clip missing:", m)
+                logging.warning("clip missing: {}".format(m))
 
             if len(u) > 0:
-                print("clip unexpected:", u)
+                logging.info("clip unexpected: {}".format(u))
         else:
-            print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
+            logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
     left_over = sd.keys()
     if len(left_over) > 0:
-        print("left over keys:", left_over)
+        logging.info("left over keys: {}".format(left_over))
 
     if output_model:
         model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
         if inital_load_device != torch.device("cpu"):
-            print("loaded straight to GPU")
+            logging.warning("loaded straight to GPU")
             model_management.load_model_gpu(model_patcher)
 
     return (model_patcher, clip, vae, clipvision)
@@ -577,7 +578,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
             if k in sd:
                 new_sd[diffusers_keys[k]] = sd.pop(k)
             else:
-                print(diffusers_keys[k], k)
+                logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
@@ -588,14 +589,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
-        print("left over keys in unet:", left_over)
+        logging.warning("left over keys in unet: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 
 def load_unet(unet_path):
     sd = comfy.utils.load_torch_file(unet_path)
     model = load_unet_state_dict(sd)
     if model is None:
-        print("ERROR UNSUPPORTED UNET", unet_path)
+        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 87e3eaa4..ff6db0d2 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -8,6 +8,7 @@ import zipfile
 from . import model_management
 import comfy.clip_model
 import json
+import logging
 
 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -137,7 +138,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                             tokens_temp += [next_new_token]
                             next_new_token += 1
                         else:
-                            print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
+                            logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
                 while len(tokens_temp) < len(x):
                     tokens_temp += [self.special_tokens["pad"]]
                 out_tokens += [tokens_temp]
@@ -329,9 +330,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
             else:
                 embed = torch.load(embed_path, map_location="cpu")
         except Exception as e:
-            print(traceback.format_exc())
-            print()
-            print("error loading embedding, skipping loading:", embedding_name)
+            logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
             return None
 
     if embed_out is None:
@@ -422,7 +421,7 @@ class SDTokenizer:
                     embedding_name = word[len(self.embedding_identifier):].strip('\n')
                     embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
+                        logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
                     else:
                         if len(embed.shape) == 1:
                             tokens.append([(embed, weight)])
diff --git a/comfy/utils.py b/comfy/utils.py
index 5deb14cd..8caecd86 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -5,6 +5,7 @@ import comfy.checkpoint_pickle
 import safetensors.torch
 import numpy as np
 from PIL import Image
+import logging
 
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
@@ -14,14 +15,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
     else:
         if safe_load:
             if not 'weights_only' in torch.load.__code__.co_varnames:
-                print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
+                logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                 safe_load = False
         if safe_load:
             pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
         else:
             pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
     if "global_step" in pl_sd:
-        print(f"Global Step: {pl_sd['global_step']}")
+        logging.info(f"Global Step: {pl_sd['global_step']}")
     if "state_dict" in pl_sd:
         sd = pl_sd["state_dict"]
     else:
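
Note: the logging configuration this patch adds runs at import time in comfy/cli_args.py, so every module that later calls logging.warning()/logging.info() inherits the level chosen there. A minimal standalone sketch of the same pattern, assuming nothing beyond what the patch itself introduces (the flag name and basicConfig call mirror the patch; the demo messages are illustrative only):

    import argparse
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
    args = parser.parse_args()

    # Same pattern as the patch: default to WARNING so normal runs stay quiet,
    # switch to DEBUG when --verbose is given.
    logging_level = logging.WARNING
    if args.verbose:
        logging_level = logging.DEBUG
    logging.basicConfig(format="%(message)s", level=logging_level)

    logging.warning("shown by default, like the startup VRAM report")    # hypothetical message
    logging.info("shown only with --verbose, like the leftover-key reports")
    logging.debug("shown only with --verbose")

With the default WARNING level, the messages the patch downgrades to logging.info() (leftover keys, unexpected keys, global step) disappear from ordinary runs, while everything routed through logging.warning() and logging.error() keeps the visibility the old print() calls had.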