@@ -1,4 +1,5 @@
 import psutil
+import logging
 from enum import Enum
 from comfy.cli_args import args
 import comfy.utils
@@ -29,7 +30,7 @@ lowvram_available = True
 xpu_available = False
 
 if args.deterministic:
-    print("Using deterministic algorithms for pytorch")
+    logging.warning("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
 directml_enabled = False
@@ -41,7 +42,7 @@ if args.directml is not None:
         directml_device = torch_directml.device()
     else:
         directml_device = torch_directml.device(device_index)
-    print("Using directml with device:", torch_directml.device_name(device_index))
+    logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
@@ -117,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False):
 
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
-        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
         set_vram_to = VRAMState.LOW_VRAM
 
 try:
@@ -143,12 +144,10 @@ else:
         pass
     try:
         XFORMERS_VERSION = xformers.version.__version__
-        print("xformers version:", XFORMERS_VERSION)
+        logging.warning("xformers version: {}".format(XFORMERS_VERSION))
         if XFORMERS_VERSION.startswith("0.0.18"):
-            print()
-            print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-            print("Please downgrade or upgrade xformers to a different version.")
-            print()
+            logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
+            logging.warning("Please downgrade or upgrade xformers to a different version.\n")
             XFORMERS_ENABLED_VAE = False
     except:
         pass
@@ -213,11 +212,11 @@ elif args.highvram or args.gpu_only:
 FORCE_FP32 = False
 FORCE_FP16 = False
 if args.force_fp32:
-    print("Forcing FP32, if this improves things please report it.")
+    logging.warning("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 
 if args.force_fp16:
-    print("Forcing FP16.")
+    logging.warning("Forcing FP16.")
     FORCE_FP16 = True
 
 if lowvram_available:
@@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU:
 if cpu_state == CPUState.MPS:
     vram_state = VRAMState.SHARED
 
-print(f"Set vram state to: {vram_state.name}")
+logging.warning(f"Set vram state to: {vram_state.name}")
 
 DISABLE_SMART_MEMORY = args.disable_smart_memory
 
 if DISABLE_SMART_MEMORY:
-    print("Disabling smart memory management")
+    logging.warning("Disabling smart memory management")
 
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
@@ -254,11 +253,11 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    print("Device:", get_torch_device_name(get_torch_device()))
+    logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
-    print("Could not pick default device.")
+    logging.warning("Could not pick default device.")
 
-print("VAE dtype:", VAE_DTYPE)
+logging.warning("VAE dtype: {}".format(VAE_DTYPE))
 
 current_loaded_models = []
 
@@ -301,7 +300,7 @@ class LoadedModel:
             raise e
 
         if lowvram_model_memory > 0:
-            print("loading in lowvram mode", lowvram_model_memory / (1024 * 1024))
+            logging.warning("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "comfy_cast_weights"):
@@ -314,7 +313,7 @@ class LoadedModel:
                 elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("lowvram: loaded module regularly", m)
+                    logging.warning("lowvram: loaded module regularly {}".format(m))
 
             self.model_accelerated = True
 
@@ -348,7 +347,7 @@ def unload_model_clones(model):
             to_unload = [i] + to_unload
 
     for i in to_unload:
-        print("unload clone", i)
+        logging.warning("unload clone {}".format(i))
         current_loaded_models.pop(i).model_unload()
 
 def free_memory(memory_required, device, keep_loaded=[]):
@@ -390,7 +389,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
+                logging.warning(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
     if len(models_to_load) == 0:
@@ -400,7 +399,7 @@ def load_models_gpu(models, memory_required=0):
                 free_memory(extra_mem, d, models_already_loaded)
         return
 
-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
 
     total_memory_required = {}
     for loaded_model in models_to_load:
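Note on the logging setup (illustrative, not part of this diff): Python's stdlib logging module already writes warning-level records to stderr when nothing is configured, so the messages introduced above appear out of the box. The sketch below assumes a logging.basicConfig call somewhere at startup, which this diff does not add, and uses placeholder memory values.

    import logging

    # Assumed startup configuration (not in this diff): pick a message format and
    # the minimum level to emit; WARNING messages pass either way.
    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)

    # Same pattern as the calls introduced above, with placeholder numbers.
    logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(8192, 32768))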