|
|
|
@ -474,15 +474,11 @@ class StableCascade_B(BaseModel):
|
|
|
|
|
|
|
|
|
|
clip_text_pooled = kwargs["pooled_output"] |
|
|
|
|
if clip_text_pooled is not None: |
|
|
|
|
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) |
|
|
|
|
out['clip'] = comfy.conds.CONDRegular(clip_text_pooled) |
|
|
|
|
|
|
|
|
|
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched |
|
|
|
|
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) |
|
|
|
|
|
|
|
|
|
out["effnet"] = comfy.conds.CONDRegular(prior) |
|
|
|
|
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) |
|
|
|
|
|
|
|
|
|
cross_attn = kwargs.get("cross_attn", None) |
|
|
|
|
if cross_attn is not None: |
|
|
|
|
out['clip'] = comfy.conds.CONDCrossAttn(cross_attn) |
|
|
|
|
return out |
|
|
|
|