@@ -2,6 +2,7 @@ import psutil
from enum import Enum
from comfy.cli_args import args
import torch
import sys
class VRAMState(Enum):
    DISABLED = 0    #No vram present: no need to move models to vram
@@ -221,132 +222,161 @@ except:
    print("Could not pick default device.")
current_loaded_model = None
current_gpu_controlnets = []
current_loaded_models = []
model_accelerated = False
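#LoadedModel wraps a model patcher together with the device it should run on, so that
#several models can be tracked at once in the current_loaded_models list.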
class LoadedModel:
    def __init__(self, model):
        self.model = model
        self.model_accelerated = False
        self.device = model.load_device
    def model_memory(self):
        return self.model.model_size()
def unload_model():
    global current_loaded_model
    global model_accelerated
    global current_gpu_controlnets
    global vram_state
    def model_memory_required(self, device):
        if device == self.model.current_device:
            return 0
        else:
            return self.model_memory()
    if current_loaded_model is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False
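    #model_load patches the model onto its load device; when lowvram_model_memory is nonzero
    #the weights are instead split between the GPU and the CPU by accelerate (see below).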
    def model_load(self, lowvram_model_memory=0):
        patch_model_to = None
        if lowvram_model_memory == 0:
            patch_model_to = self.device
        current_loaded_model.unpatch_model()
        current_loaded_model.model.to(current_loaded_model.offload_device)
        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
        current_loaded_model = None
        if vram_state != VRAMState.HIGH_VRAM:
            soft_empty_cache()
        self.model.model_patches_to(self.device)
        self.model.model_patches_to(self.model.model_dtype())
    if vram_state != VRAMState.HIGH_VRAM:
        if len(current_gpu_controlnets) > 0:
            for n in current_gpu_controlnets:
                n.cpu()
            current_gpu_controlnets = []
        try:
            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
        except Exception as e:
            self.model.unpatch_model(self.model.offload_device)
            self.model_unload()
            raise e
def minimum_inference_memory():
    return (768 * 1024 * 1024)
        if lowvram_model_memory > 0:
            print("loading in lowvram mode", lowvram_model_memory / (1024 * 1024))
            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
            self.model_accelerated = True
def load_model_gpu(model):
    global current_loaded_model
    global vram_state
    global model_accelerated
        return self.real_model
    if model is current_loaded_model:
        return
    unload_model()
    def model_unload(self):
        if self.model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(self.real_model)
            self.model_accelerated = False
    torch_dev = model.load_device
    model.model_patches_to(torch_dev)
    model.model_patches_to(model.model_dtype())
    current_loaded_model = model
        self.model.unpatch_model(self.model.offload_device)
        self.model.model_patches_to(self.model.offload_device)
    if is_device_cpu(torch_dev):
        vram_set_state = VRAMState.DISABLED
    else:
        vram_set_state = vram_state
    if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
        model_size = model.model_size()
        current_free_mem = get_free_memory(torch_dev)
        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
        if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
            vram_set_state = VRAMState.LOW_VRAM
    real_model = model.model
    patch_model_to = None
    if vram_set_state == VRAMState.DISABLED:
        pass
    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
        model_accelerated = False
        patch_model_to = torch_dev
    def __eq__(self, other):
        return self.model is other.model
    try:
        real_model = model.patch_model(device_to=patch_model_to)
    except Exception as e:
        model.unpatch_model()
        unload_model()
        raise e
    if patch_model_to is not None:
        real_model.to(torch_dev)
    if vram_set_state == VRAMState.NO_VRAM:
        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
        model_accelerated = True
    elif vram_set_state == VRAMState.LOW_VRAM:
        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
        model_accelerated = True
    return current_loaded_model
def load_controlnet_gpu(control_models):
    global current_gpu_controlnets
def minimum_inference_memory():
    return (1024 * 1024 * 1024)
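#unload_model_clones unloads every entry in current_loaded_models that model.is_clone()
#reports as a clone of the given model, so the same weights are not kept loaded twice.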
def unload_model_clones(model):
    to_unload = []
    for i in range(len(current_loaded_models)):
        if model.is_clone(current_loaded_models[i].model):
            to_unload = [i] + to_unload
    for i in to_unload:
        print("unload clone", i)
        current_loaded_models.pop(i).model_unload()
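#free_memory unloads models from the tail of current_loaded_models (the least recently used
#end) until the device reports at least memory_required bytes free, skipping keep_loaded.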
def free_memory(memory_required, device, keep_loaded=[]):
    unloaded_model = False
    for i in range(len(current_loaded_models) - 1, -1, -1):
        current_free_mem = get_free_memory(device)
        if current_free_mem > memory_required:
            break
        shift_model = current_loaded_models[i]
        if shift_model.device == device:
            if shift_model not in keep_loaded:
                current_loaded_models.pop(i).model_unload()
                unloaded_model = True
    if unloaded_model:
        soft_empty_cache()
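#load_models_gpu replaces the old single-model load path: already loaded models are moved to
#the front of the cache, enough VRAM is freed for the new ones, and they are then loaded
#(falling back to lowvram mode when necessary). A caller might use it roughly like
#load_models_gpu([unet_patcher, clip_patcher]) -- hypothetical names, for illustration only.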
def load_models_gpu(models, memory_required=0):
    global vram_state
    if vram_state == VRAMState.DISABLED:
        return
    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
        for m in control_models:
            if hasattr(m, 'set_lowvram'):
                m.set_lowvram(True)
        #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
    inference_memory = minimum_inference_memory()
    extra_mem = max(inference_memory, memory_required)
    models_to_load = []
    models_already_loaded = []
    for x in models:
        loaded_model = LoadedModel(x)
        if loaded_model in current_loaded_models:
            index = current_loaded_models.index(loaded_model)
            current_loaded_models.insert(0, current_loaded_models.pop(index))
            models_already_loaded.append(loaded_model)
        else:
            models_to_load.append(loaded_model)
    if len(models_to_load) == 0:
        devs = set(map(lambda a: a.device, models_already_loaded))
        for d in devs:
            if d != torch.device("cpu"):
                free_memory(extra_mem, d, models_already_loaded)
        return
    models = []
    for m in control_models:
        models += m.get_models()
    print("loading new")
    for m in current_gpu_controlnets:
        if m not in models:
            m.cpu()
    total_memory_required = {}
    for loaded_model in models_to_load:
        unload_model_clones(loaded_model.model)
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
    device = get_torch_device()
    current_gpu_controlnets = []
    for m in models:
        current_gpu_controlnets.append(m.to(device))
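    #free 1.3x the memory the new models will need plus extra_mem as headroom before loading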
    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state
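        #only fall back to lowvram mode when the model will not fit next to the inference reserve;
        #the lowvram budget is the free VRAM minus ~1GiB, divided by 1.3, with a 256MiB floor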
        lowvram_model_memory = 0
        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
            model_size = loaded_model.model_memory_required(torch_dev)
            current_free_mem = get_free_memory(torch_dev)
            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
            if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                vram_set_state = VRAMState.LOW_VRAM
            else:
                lowvram_model_memory = 0
def load_if_low_vram(model):
    global vram_state
    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
        return model.to(get_torch_device())
    return model
        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 256 * 1024 * 1024
def unload_if_low_vram(model):
    global vram_state
    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
        return model.cpu()
    return model
        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
        current_loaded_models.insert(0, loaded_model)
    return
def load_model_gpu(model):
    return load_models_gpu([model])
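#cleanup_models drops cached models that nothing else references anymore: a refcount of 2 means
#only the cache entry and the temporary getrefcount() argument still point at the model.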
def cleanup_models():
    to_delete = []
    for i in range(len(current_loaded_models)):
        print(sys.getrefcount(current_loaded_models[i].model))
        if sys.getrefcount(current_loaded_models[i].model) <= 2:
            to_delete = [i] + to_delete
    for i in to_delete:
        x = current_loaded_models.pop(i)
        x.model_unload()
        del x
def unet_offload_device():
    if vram_state == VRAMState.HIGH_VRAM:
@@ -354,6 +384,21 @@ def unet_offload_device():
    else:
        return torch.device("cpu")
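#unet_inital_load_device estimates the model footprint as parameter count times bytes per
#element and only loads onto the GPU initially if it has more free memory than the CPU and the
#model fits.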
def unet_inital_load_device(parameters, dtype):
    torch_dev = get_torch_device()
    if vram_state == VRAMState.HIGH_VRAM:
        return torch_dev
    cpu_dev = torch.device("cpu")
    model_size = dtype.itemsize * parameters
    mem_dev = get_free_memory(torch_dev)
    mem_cpu = get_free_memory(cpu_dev)
    if mem_dev > mem_cpu and model_size < mem_dev:
        return torch_dev
    else:
        return cpu_dev
def text_encoder_offload_device():
    if args.gpu_only:
        return get_torch_device()
@@ -456,6 +501,13 @@ def get_free_memory(dev=None, torch_free_too=False):
    else:
        return mem_free_total
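#batch_area_memory estimates the VRAM in bytes needed to sample a given latent area, using the
#same empirical formulas as maximum_batch_area below.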
def batch_area_memory(area):
    if xformers_enabled() or pytorch_attention_flash_attention():
        #TODO: these formulas are copied from maximum_batch_area below
        return (area / 20) * (1024 * 1024)
    else:
        return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area():
    global vram_state
    if vram_state == VRAMState.NO_VRAM: