|
|
|
@ -4,6 +4,7 @@ NO_VRAM = 1
|
|
|
|
|
LOW_VRAM = 2 |
|
|
|
|
NORMAL_VRAM = 3 |
|
|
|
|
HIGH_VRAM = 4 |
|
|
|
|
MPS = 4 |
|
|
|
|
|
|
|
|
|
accelerate_enabled = False |
|
|
|
|
vram_state = NORMAL_VRAM |
|
|
|
@ -61,7 +62,8 @@ if "--novram" in sys.argv:
|
|
|
|
|
set_vram_to = NO_VRAM |
|
|
|
|
if "--highvram" in sys.argv: |
|
|
|
|
vram_state = HIGH_VRAM |
|
|
|
|
|
|
|
|
|
if torch.backends.mps.is_available(): |
|
|
|
|
vram_state = MPS |
|
|
|
|
|
|
|
|
|
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: |
|
|
|
|
try: |
|
|
|
@ -79,7 +81,7 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
|
|
|
|
|
if "--cpu" in sys.argv: |
|
|
|
|
vram_state = CPU |
|
|
|
|
|
|
|
|
|
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state]) |
|
|
|
|
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
current_loaded_model = None |
|
|
|
@ -128,6 +130,12 @@ def load_model_gpu(model):
|
|
|
|
|
current_loaded_model = model |
|
|
|
|
if vram_state == CPU: |
|
|
|
|
pass |
|
|
|
|
elif vram_state == MPS: |
|
|
|
|
# print(inspect.getmro(real_model.__class__)) |
|
|
|
|
# print(dir(real_model)) |
|
|
|
|
mps_device = torch.device("mps") |
|
|
|
|
real_model.to(mps_device) |
|
|
|
|
pass |
|
|
|
|
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: |
|
|
|
|
model_accelerated = False |
|
|
|
|
real_model.cuda() |
|
|
|
@ -146,6 +154,9 @@ def load_controlnet_gpu(models):
|
|
|
|
|
global vram_state |
|
|
|
|
if vram_state == CPU: |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
if vram_state == MPS: |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
if vram_state == LOW_VRAM or vram_state == NO_VRAM: |
|
|
|
|
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after |
|
|
|
@ -173,6 +184,8 @@ def unload_if_low_vram(model):
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
def get_torch_device(): |
|
|
|
|
if vram_state == MPS: |
|
|
|
|
return torch.device("mps") |
|
|
|
|
if vram_state == CPU: |
|
|
|
|
return torch.device("cpu") |
|
|
|
|
else: |
|
|
|
@ -195,7 +208,7 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|
|
|
|
if dev is None: |
|
|
|
|
dev = get_torch_device() |
|
|
|
|
|
|
|
|
|
if hasattr(dev, 'type') and dev.type == 'cpu': |
|
|
|
|
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): |
|
|
|
|
mem_free_total = psutil.virtual_memory().available |
|
|
|
|
mem_free_torch = mem_free_total |
|
|
|
|
else: |
|
|
|
@ -224,8 +237,12 @@ def cpu_mode():
|
|
|
|
|
global vram_state |
|
|
|
|
return vram_state == CPU |
|
|
|
|
|
|
|
|
|
def mps_mode(): |
|
|
|
|
global vram_state |
|
|
|
|
return vram_state == MPS |
|
|
|
|
|
|
|
|
|
def should_use_fp16(): |
|
|
|
|
if cpu_mode(): |
|
|
|
|
if cpu_mode() or mps_mode(): |
|
|
|
|
return False #TODO ? |
|
|
|
|
|
|
|
|
|
if torch.cuda.is_bf16_supported(): |
|
|
|
|