diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py index e8563a35..49ce5ae3 100644 --- a/comfy/k_diffusion/external.py +++ b/comfy/k_diffusion/external.py @@ -66,7 +66,7 @@ class DiscreteSchedule(nn.Module): def sigma_to_t(self, sigma, quantize=None): quantize = self.quantize if quantize is None else quantize log_sigma = sigma.log() - dists = log_sigma - self.log_sigmas[:, None] + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] if quantize: return dists.abs().argmin(dim=0).view(sigma.shape) low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3e098124..f7374aa1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,11 +1,48 @@ +CPU = 0 +NO_VRAM = 1 +LOW_VRAM = 2 +NORMAL_VRAM = 3 + +accelerate_enabled = False +vram_state = NORMAL_VRAM + +import sys + +set_vram_to = NORMAL_VRAM +if "--lowvram" in sys.argv: + set_vram_to = LOW_VRAM +if "--novram" in sys.argv: + set_vram_to = NO_VRAM + +if set_vram_to != NORMAL_VRAM: + try: + import accelerate + accelerate_enabled = True + vram_state = set_vram_to + except Exception as e: + import traceback + print(traceback.format_exc()) + print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") + + +print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state]) + current_loaded_model = None +model_accelerated = False + + def unload_model(): global current_loaded_model + global model_accelerated if current_loaded_model is not None: + if model_accelerated: + accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) + model_accelerated = False + current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None @@ -13,6 +50,9 @@ def unload_model(): def load_model_gpu(model): global current_loaded_model + global vram_state + global model_accelerated + if model is current_loaded_model: return unload_model() @@ -22,5 +62,16 @@ def load_model_gpu(model): model.unpatch_model() raise e current_loaded_model = model - real_model.cuda() + if vram_state == CPU: + pass + elif vram_state == NORMAL_VRAM: + model_accelerated = False + real_model.cuda() + else: + if vram_state == NO_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) + elif vram_state == LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "1GiB", "cpu": "16GiB"}) + accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") + model_accelerated = True return current_loaded_model diff --git a/main.py b/main.py index 4f7bb65b..0f466a3f 100644 --- a/main.py +++ b/main.py @@ -14,6 +14,9 @@ if __name__ == "__main__": print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") print() + print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.") + print("\t--novram\t\t\tWhen lowvram isn't enough.") + print() exit() if '--dont-upcast-attention' in sys.argv: diff --git a/requirements.txt b/requirements.txt index 64cc3fc2..cc59cf1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ transformers safetensors pytorch_lightning +accelerate +