|
|
|
@ -138,11 +138,13 @@ class ControlBase:
|
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
class ControlNet(ControlBase): |
|
|
|
|
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): |
|
|
|
|
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): |
|
|
|
|
super().__init__(device) |
|
|
|
|
self.control_model = control_model |
|
|
|
|
self.load_device = load_device |
|
|
|
|
if control_model is not None: |
|
|
|
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) |
|
|
|
|
|
|
|
|
|
self.global_average_pooling = global_average_pooling |
|
|
|
|
self.model_sampling_current = None |
|
|
|
|
self.manual_cast_dtype = manual_cast_dtype |
|
|
|
@ -183,7 +185,9 @@ class ControlNet(ControlBase):
|
|
|
|
|
return self.control_merge(None, control, control_prev, output_dtype) |
|
|
|
|
|
|
|
|
|
def copy(self): |
|
|
|
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) |
|
|
|
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) |
|
|
|
|
c.control_model = self.control_model |
|
|
|
|
c.control_model_wrapped = self.control_model_wrapped |
|
|
|
|
self.copy_to(c) |
|
|
|
|
return c |
|
|
|
|
|
|
|
|
|