|
|
|
@ -99,10 +99,40 @@ class LatentRebatch:
|
|
|
|
|
|
|
|
|
|
return (output_list,) |
|
|
|
|
|
|
|
|
|
class ImageRebatch: |
|
|
|
|
@classmethod |
|
|
|
|
def INPUT_TYPES(s): |
|
|
|
|
return {"required": { "images": ("IMAGE",), |
|
|
|
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), |
|
|
|
|
}} |
|
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
|
|
|
INPUT_IS_LIST = True |
|
|
|
|
OUTPUT_IS_LIST = (True, ) |
|
|
|
|
|
|
|
|
|
FUNCTION = "rebatch" |
|
|
|
|
|
|
|
|
|
CATEGORY = "image/batch" |
|
|
|
|
|
|
|
|
|
def rebatch(self, images, batch_size): |
|
|
|
|
batch_size = batch_size[0] |
|
|
|
|
|
|
|
|
|
output_list = [] |
|
|
|
|
all_images = [] |
|
|
|
|
for img in images: |
|
|
|
|
for i in range(img.shape[0]): |
|
|
|
|
all_images.append(img[i:i+1]) |
|
|
|
|
|
|
|
|
|
for i in range(0, len(all_images), batch_size): |
|
|
|
|
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) |
|
|
|
|
|
|
|
|
|
return (output_list,) |
|
|
|
|
|
|
|
|
|
NODE_CLASS_MAPPINGS = { |
|
|
|
|
"RebatchLatents": LatentRebatch, |
|
|
|
|
"RebatchImages": ImageRebatch, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
NODE_DISPLAY_NAME_MAPPINGS = { |
|
|
|
|
"RebatchLatents": "Rebatch Latents", |
|
|
|
|
"RebatchImages": "Rebatch Images", |
|
|
|
|
} |