diff --git a/comfy/sd.py b/comfy/sd.py index 9864ef0a..ac13d8bc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -489,21 +489,39 @@ def load_controlnet(ckpt_path, model=None): if model_management.should_use_fp16() and controlnet_data[key].dtype == torch.float16: use_fp16 = True - control_model = cldm.ControlNet(image_size=32, - in_channels=4, - hint_channels=3, - model_channels=320, - attention_resolutions=[ 4, 2, 1 ], - num_res_blocks=2, - channel_mult=[ 1, 2, 4, 4 ], - num_heads=8, - use_spatial_transformer=True, - transformer_depth=1, - context_dim=context_dim, - use_checkpoint=True, - legacy=False, - use_fp16=use_fp16) - + if context_dim == 768: + #SD1.x + control_model = cldm.ControlNet(image_size=32, + in_channels=4, + hint_channels=3, + model_channels=320, + attention_resolutions=[ 4, 2, 1 ], + num_res_blocks=2, + channel_mult=[ 1, 2, 4, 4 ], + num_heads=8, + use_spatial_transformer=True, + transformer_depth=1, + context_dim=context_dim, + use_checkpoint=True, + legacy=False, + use_fp16=use_fp16) + else: + #SD2.x + control_model = cldm.ControlNet(image_size=32, + in_channels=4, + hint_channels=3, + model_channels=320, + attention_resolutions=[ 4, 2, 1 ], + num_res_blocks=2, + channel_mult=[ 1, 2, 4, 4 ], + num_head_channels=64, + use_spatial_transformer=True, + use_linear_in_transformer=True, + transformer_depth=1, + context_dim=context_dim, + use_checkpoint=True, + legacy=False, + use_fp16=use_fp16) if pth: if 'difference' in controlnet_data: if model is not None: