@@ -3,9 +3,7 @@ import torch
 
 def reshape_latent_to(target_shape, latent):
     if latent.shape[1:] != target_shape[1:]:
-        latent.movedim(1, -1)
         latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
-        latent.movedim(-1, 1)
     return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
 
 class LatentAdd:
@@ -102,9 +100,32 @@ class LatentInterpolate:
         samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
         return (samples_out,)
 
+class LatentBatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "batch"
+
+    CATEGORY = "latent/batch"
+
+    def batch(self, samples1, samples2):
+        samples_out = samples1.copy()
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        if s1.shape[1:] != s2.shape[1:]:
+            s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
+        s = torch.cat((s1, s2), dim=0)
+        samples_out["samples"] = s
+        samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
+        return (samples_out,)
+
 NODE_CLASS_MAPPINGS = {
     "LatentAdd": LatentAdd,
     "LatentSubtract": LatentSubtract,
     "LatentMultiply": LatentMultiply,
     "LatentInterpolate": LatentInterpolate,
+    "LatentBatch": LatentBatch,
 }
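
For reference, the behavior the new LatentBatch node introduces can be sketched outside ComfyUI with plain torch. This is a minimal sketch, not the node itself: torch.nn.functional.interpolate stands in for comfy.utils.common_upscale (which additionally supports center cropping), and the helper name batch_latents_sketch plus the tensor shapes are made up for illustration.

import torch
import torch.nn.functional as F

def batch_latents_sketch(samples1, samples2):
    # Mirror the resize-then-concat behavior of LatentBatch.batch():
    # bring the second latent to the first's spatial size, concatenate
    # along the batch dimension, and merge the batch_index lists.
    out = samples1.copy()
    s1, s2 = samples1["samples"], samples2["samples"]
    if s1.shape[1:] != s2.shape[1:]:
        # Stand-in for comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
        s2 = F.interpolate(s2, size=(s1.shape[2], s1.shape[3]), mode="bilinear", align_corners=False)
    out["samples"] = torch.cat((s1, s2), dim=0)
    out["batch_index"] = samples1.get("batch_index", list(range(s1.shape[0]))) + \
                         samples2.get("batch_index", list(range(s2.shape[0])))
    return out

# Illustrative SD-style latents: 4 channels at 1/8 of the image resolution.
a = {"samples": torch.randn(1, 4, 64, 64)}
b = {"samples": torch.randn(2, 4, 32, 32)}
merged = batch_latents_sketch(a, b)
print(merged["samples"].shape)  # torch.Size([3, 4, 64, 64])
print(merged["batch_index"])    # [0, 0, 1]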