From 9ac0b487acf569ebe8a2d87ed750fed58b59262d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 8 Dec 2023 02:35:45 -0500
Subject: [PATCH] Make --gpu-only put intermediate values in GPU memory instead of cpu.

---
 comfy/clip_vision.py                  |  4 ++--
 comfy/model_management.py             |  6 ++++++
 comfy/sample.py                       |  4 ++--
 comfy/sd.py                           | 23 ++++++++++++-----------
 comfy/sd1_clip.py                     |  6 +++---
 comfy/utils.py                        | 12 ++++++------
 comfy_extras/nodes_canny.py           |  2 +-
 comfy_extras/nodes_post_processing.py |  2 +-
 nodes.py                              |  6 +++---
 9 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 9e2e03d7..449be8e4 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -54,10 +54,10 @@ class ClipVisionModel():
             t = outputs[k]
             if t is not None:
                 if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].cpu()
+                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
                     outputs["hidden_states"] = None
                 else:
-                    outputs[k] = t.cpu()
+                    outputs[k] = t.to(comfy.model_management.intermediate_device())

         return outputs

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3588d350..ef9bec54 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -508,6 +508,12 @@ def text_encoder_dtype(device=None):
     else:
         return torch.float32

+def intermediate_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def vae_device():
     return get_torch_device()
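
The helper above is the whole mechanism: under --gpu-only, intermediate_device() returns the main torch device, otherwise the CPU. The rest of the patch mechanically replaces hard-coded .cpu() transfers with .to(comfy.model_management.intermediate_device()). A minimal standalone sketch of the idiom (the gpu_only argument here is a stand-in for ComfyUI's parsed args.gpu_only, which the real helper reads):

    import torch

    def intermediate_device(gpu_only: bool) -> torch.device:
        # Where intermediate tensors should live once GPU compute finishes:
        # stay on the GPU under --gpu-only, otherwise move to the CPU.
        if gpu_only and torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    samples = torch.randn(1, 4, 64, 64)
    samples = samples.to(intermediate_device(gpu_only=True))  # was: samples.cpu()
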
diff --git a/comfy/sample.py b/comfy/sample.py
index bcbed334..eadd6dcc 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -98,7 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())

     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
@@ -111,7 +111,7 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     sigmas = sigmas.to(model.load_device)

     samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())
     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
diff --git a/comfy/sd.py b/comfy/sd.py
index f4f84d0a..43e201d3 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -190,6 +190,7 @@ class VAE:
         offload_device = model_management.vae_offload_device()
         self.vae_dtype = model_management.vae_dtype()
         self.first_stage_model.to(self.vae_dtype)
+        self.output_device = model_management.intermediate_device()

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -201,9 +202,9 @@ class VAE:

         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output
@@ -214,9 +215,9 @@ class VAE:
         pbar = comfy.utils.ProgressBar(steps)

         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
-        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples
@@ -228,15 +229,15 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        pixel_samples = pixel_samples.cpu().movedim(1,-1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -252,10 +253,10 @@ class VAE:
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()

         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
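
In comfy/sd.py the VAE caches the choice once as self.output_device, preallocates its result buffers there, and moves each decoded or encoded batch onto it, including on the tiled OOM-fallback paths. A hedged sketch of that batched pattern (batched_apply and fn are illustrative names, with fn standing in for first_stage_model.decode):

    import torch

    def batched_apply(samples_in, fn, out_channels, scale, batch_number, output_device):
        # Preallocate the full result on the output device, then fill it batch
        # by batch; fn runs on the compute device and only its result moves.
        out = torch.empty((samples_in.shape[0], out_channels,
                           samples_in.shape[2] * scale, samples_in.shape[3] * scale),
                          device=output_device)
        for x in range(0, samples_in.shape[0], batch_number):
            out[x:x+batch_number] = fn(samples_in[x:x+batch_number]).to(output_device)
        return out
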
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 1acd972c..4530168a 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -39,7 +39,7 @@ class ClipTokenWeightEncoder:

         out, pooled = self.encode(to_encode)
         if pooled is not None:
-            first_pooled = pooled[0:1].cpu()
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
             first_pooled = pooled

@@ -56,8 +56,8 @@ class ClipTokenWeightEncoder:
             output.append(z)

         if (len(output) == 0):
-            return out[-1:].cpu(), first_pooled
-        return torch.cat(output, dim=-2).cpu(), first_pooled
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled

 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
diff --git a/comfy/utils.py b/comfy/utils.py
index 50557704..f8026dda 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -376,7 +376,7 @@ def lanczos(samples, width, height):
     images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
     images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
     result = torch.stack(images)
-    return result
+    return result.to(samples.device, samples.dtype)

 def common_upscale(samples, width, height, upscale_method, crop):
     if crop == "center":
@@ -405,17 +405,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))

 @torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
-    output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+    output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
     for b in range(samples.shape[0]):
         s = samples[b:b+1]
-        out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
-        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
+        out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
+        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         for y in range(0, s.shape[2], tile_y - overlap):
             for x in range(0, s.shape[3], tile_x - overlap):
                 s_in = s[:,:,y:y+tile_y,x:x+tile_x]

-                ps = function(s_in).cpu()
+                ps = function(s_in).to(output_device)
                 mask = torch.ones_like(ps)
                 feather = round(overlap * upscale_amount)
                 for t in range(feather):
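
Two things change in comfy/utils.py: lanczos now returns its result on the caller's device and dtype (the PIL round-trip always yields a float32 CPU tensor), and tiled_scale takes an explicit output_device instead of assuming "cpu", which the VAE threads through as output_device=self.output_device. The accumulation buffers must live on that device because every processed tile is summed into out and normalized by out_div. A toy illustration of that accumulate-and-divide structure (identity function, no overlap or feathering; toy_tiled_identity is a hypothetical name, not part of the patch):

    import torch

    def toy_tiled_identity(samples, tile=64, output_device="cpu"):
        # out accumulates tile results, out_div counts how many tiles touched
        # each pixel; dividing at the end averages overlapping contributions.
        out = torch.zeros(samples.shape, device=output_device)
        out_div = torch.zeros(samples.shape, device=output_device)
        for y in range(0, samples.shape[2], tile):
            for x in range(0, samples.shape[3], tile):
                t = samples[:, :, y:y+tile, x:x+tile]
                out[:, :, y:y+tile, x:x+tile] += t.to(output_device)
                out_div[:, :, y:y+tile, x:x+tile] += 1.0
        return out / out_div
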
diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py
index 94d453f2..730dded5 100644
--- a/comfy_extras/nodes_canny.py
+++ b/comfy_extras/nodes_canny.py
@@ -291,7 +291,7 @@ class Canny:

     def detect_edge(self, image, low_threshold, high_threshold):
         output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
-        img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
+        img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
         return (img_out,)

 NODE_CLASS_MAPPINGS = {
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 12704f54..71660f8a 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -226,7 +226,7 @@ class Sharpen:
         batch_size, height, width, channels = image.shape

         kernel_size = sharpen_radius * 2 + 1
-        kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
+        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
         center = kernel_size // 2
         kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
         kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
diff --git a/nodes.py b/nodes.py
index 24e591fd..db96e0e2 100644
--- a/nodes.py
+++ b/nodes.py
@@ -947,8 +947,8 @@ class GLIGENTextBoxApply:
         return (c, )

 class EmptyLatentImage:
-    def __init__(self, device="cpu"):
-        self.device = device
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()

     @classmethod
     def INPUT_TYPES(s):
@@ -961,7 +961,7 @@ class EmptyLatentImage:
     CATEGORY = "latent"

     def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
         return ({"samples":latent}, )
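
Finally, EmptyLatentImage now allocates new latents directly on the intermediate device, so under --gpu-only a workflow's starting latent never touches system RAM. A quick way to confirm which device will be used, from inside a running ComfyUI process such as a custom node (assumes ComfyUI was launched as python main.py --gpu-only on a CUDA machine):

    import comfy.model_management
    # prints a CUDA device when --gpu-only was passed, cpu otherwise
    print(comfy.model_management.intermediate_device())
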