Browse Source

Add a way to set different conditioning for the controlnet.

pull/2751/head
comfyanonymous 9 months ago
parent
commit
25a4805e51
  1. 2
      comfy/controlnet.py
  2. 4
      comfy/model_base.py

2
comfy/controlnet.py

@ -166,7 +166,7 @@ class ControlNet(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond['c_crossattn'] context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None) y = cond.get('y', None)
if y is not None: if y is not None:
y = y.to(dtype) y = y.to(dtype)

4
comfy/model_base.py

@ -153,6 +153,10 @@ class BaseModel(torch.nn.Module):
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
if cross_attn_cnet is not None:
out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
return out return out
def load_model_weights(self, sd, unet_prefix=""): def load_model_weights(self, sd, unet_prefix=""):

Loading…
Cancel
Save