
--gpu-only now keeps the VAE on the device.

Commit 1c1b0e7299 by comfyanonymous, 1 year ago (pull/830/head)

Changed files:
  comfy/model_management.py (9 lines changed)
  comfy/sd.py (11 lines changed)
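In practice, a run launched with --gpu-only no longer round-trips the VAE weights through system RAM after each decode or encode: the offload target becomes the GPU itself. Assuming the standard main.py entrypoint, the flag is passed at launch:

    python main.py --gpu-only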
comfy/model_management.py

@@ -349,6 +349,15 @@ def text_encoder_device():
     else:
         return torch.device("cpu")

+def vae_device():
+    return get_torch_device()
+
+def vae_offload_device():
+    if args.gpu_only or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type

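For reference, a standalone sketch of the selection logic the new helper implements. Here args, vram_state, VRAMState, and get_torch_device are simplified stand-ins for the module's own globals, reduced so the snippet runs on its own:

    import torch
    from enum import Enum

    class VRAMState(Enum):
        SHARED = 0  # the real module defines several more states

    class args:          # stand-in for the parsed CLI arguments
        gpu_only = True  # i.e. launched with --gpu-only

    vram_state = None

    def get_torch_device():
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def vae_offload_device():
        # Under --gpu-only (or a shared-memory device) the "offload" target
        # is the compute device itself, so the VAE never leaves the GPU.
        if args.gpu_only or vram_state == VRAMState.SHARED:
            return get_torch_device()
        return torch.device("cpu")

    print(vae_offload_device())  # cuda under --gpu-only (when available), cpu otherwise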
comfy/sd.py

@@ -605,8 +605,9 @@ class VAE:
             self.first_stage_model.load_state_dict(sd, strict=False)

         if device is None:
-            device = model_management.get_torch_device()
+            device = model_management.vae_device()
         self.device = device
+        self.offload_device = model_management.vae_offload_device()

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
@@ -651,7 +652,7 @@ class VAE:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples

@@ -659,7 +660,7 @@ class VAE:
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return output.movedim(1,-1)

     def encode(self, pixel_samples):
@@ -679,7 +680,7 @@ class VAE:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             samples = self.encode_tiled_(pixel_samples)

-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return samples

     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -687,7 +688,7 @@ class VAE:
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return samples

     def get_sd(self):
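All four call sites keep the same load-run-offload round trip; only the destination changes. A minimal sketch of the pattern, with a toy module standing in for the real AutoencoderKL:

    import torch

    class ToyVAE(torch.nn.Module):
        # Toy stand-in for self.first_stage_model (really an AutoencoderKL).
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Conv2d(4, 3, 1)

        def forward(self, x):
            return self.proj(x)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    offload_device = device  # what vae_offload_device() yields under --gpu-only

    vae = ToyVAE()
    latents = torch.randn(1, 4, 64, 64)

    vae = vae.to(device)               # load weights onto the compute device
    pixels = vae(latents.to(device))   # decode
    vae = vae.to(offload_device)       # offload: a no-op under --gpu-only

Under the default configuration (without --gpu-only), offload_device would be torch.device("cpu"), which restores the old behavior of freeing VRAM between VAE calls.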