@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
return torch . cat ( [ tensor ] * batched_number , dim = 0 )
class ControlNet :
def __init__ ( self , control_model , device = None ) :
def __init__ ( self , control_model , global_average_pooling = False , device = None ) :
self . control_model = control_model
self . cond_hint_original = None
self . cond_hint = None
@ -630,6 +630,7 @@ class ControlNet:
device = model_management . get_torch_device ( )
self . device = device
self . previous_controlnet = None
self . global_average_pooling = global_average_pooling
def get_control ( self , x_noisy , t , cond_txt , batched_number ) :
control_prev = None
@ -665,6 +666,9 @@ class ControlNet:
key = ' output '
index = i
x = control [ i ]
if self . global_average_pooling :
x = torch . mean ( x , dim = ( 2 , 3 ) , keepdim = True ) . repeat ( 1 , 1 , x . shape [ 2 ] , x . shape [ 3 ] )
x * = self . strength
if x . dtype != output_dtype and not autocast_enabled :
x = x . to ( output_dtype )
@ -695,7 +699,7 @@ class ControlNet:
self . cond_hint = None
def copy ( self ) :
c = ControlNet ( self . control_model )
c = ControlNet ( self . control_model , global_average_pooling = self . global_average_pooling )
c . cond_hint_original = self . cond_hint_original
c . strength = self . strength
return c
@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
if use_fp16 :
control_model = control_model . half ( )
control = ControlNet ( control_model )
global_average_pooling = False
if ckpt_path . endswith ( " _shuffle.pth " ) or ckpt_path . endswith ( " _shuffle.safetensors " ) or ckpt_path . endswith ( " _shuffle_fp16.safetensors " ) : #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet ( control_model , global_average_pooling = global_average_pooling )
return control
class T2IAdapter :