@@ -30,7 +30,7 @@ lowvram_available = True
 xpu_available = False
 
 if args.deterministic:
-    logging.warning("Using deterministic algorithms for pytorch")
+    logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
 directml_enabled = False
@@ -42,7 +42,7 @@ if args.directml is not None:
         directml_device = torch_directml.device()
     else:
         directml_device = torch_directml.device(device_index)
-    logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
+    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
@@ -118,7 +118,7 @@ def get_total_memory(dev=None, torch_total_too=False):
 
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
         logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
@@ -144,7 +144,7 @@ else:
             pass
         try:
             XFORMERS_VERSION = xformers.version.__version__
-            logging.warning("xformers version: {}".format(XFORMERS_VERSION))
+            logging.info("xformers version: {}".format(XFORMERS_VERSION))
             if XFORMERS_VERSION.startswith("0.0.18"):
                 logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
                 logging.warning("Please downgrade or upgrade xformers to a different version.\n")
@@ -212,11 +212,11 @@ elif args.highvram or args.gpu_only:
 FORCE_FP32 = False
 FORCE_FP16 = False
 if args.force_fp32:
-    logging.warning("Forcing FP32, if this improves things please report it.")
+    logging.info("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 
 if args.force_fp16:
-    logging.warning("Forcing FP16.")
+    logging.info("Forcing FP16.")
     FORCE_FP16 = True
 
 if lowvram_available:
@@ -230,12 +230,12 @@ if cpu_state != CPUState.GPU:
 if cpu_state == CPUState.MPS:
     vram_state = VRAMState.SHARED
 
-logging.warning(f"Set vram state to: {vram_state.name}")
+logging.info(f"Set vram state to: {vram_state.name}")
 
 DISABLE_SMART_MEMORY = args.disable_smart_memory
 
 if DISABLE_SMART_MEMORY:
-    logging.warning("Disabling smart memory management")
+    logging.info("Disabling smart memory management")
 
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
@@ -253,11 +253,11 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
+    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
     logging.warning("Could not pick default device.")
 
-logging.warning("VAE dtype: {}".format(VAE_DTYPE))
+logging.info("VAE dtype: {}".format(VAE_DTYPE))
 
 current_loaded_models = []
 
@@ -300,7 +300,7 @@ class LoadedModel:
             raise e
 
         if lowvram_model_memory > 0:
-            logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "comfy_cast_weights"):
@@ -347,7 +347,7 @@ def unload_model_clones(model):
             to_unload = [i] + to_unload
 
     for i in to_unload:
-        logging.warning("unload clone {}".format(i))
+        logging.debug("unload clone {}".format(i))
         current_loaded_models.pop(i).model_unload()
 
 def free_memory(memory_required, device, keep_loaded=[]):
@@ -389,7 +389,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                logging.warning(f"Requested to load {x.model.__class__.__name__}")
+                logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
     if len(models_to_load) == 0:
@@ -399,7 +399,7 @@ def load_models_gpu(models, memory_required=0):
                 free_memory(extra_mem, d, models_already_loaded)
         return
 
-    logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
 
     total_memory_required = {}
     for loaded_model in models_to_load:
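Note on the effect of this change: the messages above drop from WARNING to INFO (and "unload clone" to DEBUG), so they only remain visible if the logger threshold is lowered accordingly. Below is a minimal sketch, not part of the diff, using only the standard library `logging` module; the format string is an arbitrary illustrative choice.

```python
import logging

# Hypothetical setup sketch: lower the root logger threshold so the
# INFO-level startup and model-loading messages from this change still
# appear on the console.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)

# Use logging.DEBUG instead to also surface the per-model "unload clone"
# messages, which this change demotes to DEBUG:
# logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
```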