Browse Source
* allow nodes to map over lists * make work with IS_CHANGED and VALIDATE_INPUTS * give list outputs distinct socket shape * add rebatch node * add batch index logic * add repeat latent batch * deal with noise mask edge cases in latentfrombatchpull/663/head
BlenderNeko
2 years ago
committed by
GitHub
6 changed files with 250 additions and 26 deletions
@ -0,0 +1,108 @@
|
||||
import torch |
||||
|
||||
class LatentRebatch: |
||||
@classmethod |
||||
def INPUT_TYPES(s): |
||||
return {"required": { "latents": ("LATENT",), |
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), |
||||
}} |
||||
RETURN_TYPES = ("LATENT",) |
||||
INPUT_IS_LIST = True |
||||
OUTPUT_IS_LIST = (True, ) |
||||
|
||||
FUNCTION = "rebatch" |
||||
|
||||
CATEGORY = "latent/batch" |
||||
|
||||
@staticmethod |
||||
def get_batch(latents, list_ind, offset): |
||||
'''prepare a batch out of the list of latents''' |
||||
samples = latents[list_ind]['samples'] |
||||
shape = samples.shape |
||||
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') |
||||
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]: |
||||
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") |
||||
if mask.shape[0] < samples.shape[0]: |
||||
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] |
||||
if 'batch_index' in latents[list_ind]: |
||||
batch_inds = latents[list_ind]['batch_index'] |
||||
else: |
||||
batch_inds = [x+offset for x in range(shape[0])] |
||||
return samples, mask, batch_inds |
||||
|
||||
@staticmethod |
||||
def get_slices(indexable, num, batch_size): |
||||
'''divides an indexable object into num slices of length batch_size, and a remainder''' |
||||
slices = [] |
||||
for i in range(num): |
||||
slices.append(indexable[i*batch_size:(i+1)*batch_size]) |
||||
if num * batch_size < len(indexable): |
||||
return slices, indexable[num * batch_size:] |
||||
else: |
||||
return slices, None |
||||
|
||||
@staticmethod |
||||
def slice_batch(batch, num, batch_size): |
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] |
||||
return list(zip(*result)) |
||||
|
||||
@staticmethod |
||||
def cat_batch(batch1, batch2): |
||||
if batch1[0] is None: |
||||
return batch2 |
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] |
||||
return result |
||||
|
||||
def rebatch(self, latents, batch_size): |
||||
batch_size = batch_size[0] |
||||
|
||||
output_list = [] |
||||
current_batch = (None, None, None) |
||||
processed = 0 |
||||
|
||||
for i in range(len(latents)): |
||||
# fetch new entry of list |
||||
#samples, masks, indices = self.get_batch(latents, i) |
||||
next_batch = self.get_batch(latents, i, processed) |
||||
processed += len(next_batch[2]) |
||||
# set to current if current is None |
||||
if current_batch[0] is None: |
||||
current_batch = next_batch |
||||
# add previous to list if dimensions do not match |
||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: |
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size) |
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) |
||||
current_batch = next_batch |
||||
# cat if everything checks out |
||||
else: |
||||
current_batch = self.cat_batch(current_batch, next_batch) |
||||
|
||||
# add to list if dimensions gone above target batch size |
||||
if current_batch[0].shape[0] > batch_size: |
||||
num = current_batch[0].shape[0] // batch_size |
||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size) |
||||
|
||||
for i in range(num): |
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) |
||||
|
||||
current_batch = remainder |
||||
|
||||
#add remainder |
||||
if current_batch[0] is not None: |
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size) |
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) |
||||
|
||||
#get rid of empty masks |
||||
for s in output_list: |
||||
if s['noise_mask'].mean() == 1.0: |
||||
del s['noise_mask'] |
||||
|
||||
return (output_list,) |
||||
|
||||
NODE_CLASS_MAPPINGS = { |
||||
"RebatchLatents": LatentRebatch, |
||||
} |
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = { |
||||
"RebatchLatents": "Rebatch Latents", |
||||
} |
Loading…
Reference in new issue