|
|
@@ -158,7 +158,7 @@ class SplitImageWithAlpha: |
|
|
|
def split_image_with_alpha(self, image: torch.Tensor): |
|
|
|
def split_image_with_alpha(self, image: torch.Tensor): |
|
|
|
out_images = [i[:,:,:3] for i in image] |
|
|
|
out_images = [i[:,:,:3] for i in image] |
|
|
|
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] |
|
|
|
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] |
|
|
|
result = (torch.stack(out_images), torch.stack(out_alphas)) |
|
|
|
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) |
|
|
|
return result |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -180,7 +180,7 @@ class JoinImageWithAlpha: |
|
|
|
batch_size = min(len(image), len(alpha)) |
|
|
|
batch_size = min(len(image), len(alpha)) |
|
|
|
out_images = [] |
|
|
|
out_images = [] |
|
|
|
|
|
|
|
|
|
|
|
alpha = resize_mask(alpha, image.shape[1:]) |
|
|
|
alpha = 1.0 - resize_mask(alpha, image.shape[1:]) |
|
|
|
for i in range(batch_size): |
|
|
|
for i in range(batch_size): |
|
|
|
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) |
|
|
|
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) |
|
|
|
|
|
|
|
|
|
|
|