|
|
|
@ -102,8 +102,39 @@ class StableCascade_StageB_Conditioning:
|
|
|
|
|
c.append(n) |
|
|
|
|
return (c, ) |
|
|
|
|
|
|
|
|
|
class StableCascade_SuperResolutionControlnet: |
|
|
|
|
def __init__(self, device="cpu"): |
|
|
|
|
self.device = device |
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
|
def INPUT_TYPES(s): |
|
|
|
|
return {"required": { |
|
|
|
|
"image": ("IMAGE",), |
|
|
|
|
"vae": ("VAE", ), |
|
|
|
|
}} |
|
|
|
|
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") |
|
|
|
|
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") |
|
|
|
|
FUNCTION = "generate" |
|
|
|
|
|
|
|
|
|
CATEGORY = "_for_testing/stable_cascade" |
|
|
|
|
|
|
|
|
|
def generate(self, image, vae): |
|
|
|
|
width = image.shape[-2] |
|
|
|
|
height = image.shape[-3] |
|
|
|
|
batch_size = image.shape[0] |
|
|
|
|
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) |
|
|
|
|
|
|
|
|
|
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) |
|
|
|
|
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) |
|
|
|
|
return (controlnet_input, { |
|
|
|
|
"samples": c_latent, |
|
|
|
|
}, { |
|
|
|
|
"samples": b_latent, |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
NODE_CLASS_MAPPINGS = { |
|
|
|
|
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, |
|
|
|
|
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, |
|
|
|
|
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, |
|
|
|
|
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, |
|
|
|
|
} |
|
|
|
|