|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
import torch |
|
|
|
|
import math |
|
|
|
|
import os |
|
|
|
|
import contextlib |
|
|
|
|
import comfy.utils |
|
|
|
|
import comfy.model_management |
|
|
|
|
import comfy.model_detection |
|
|
|
@ -147,24 +148,31 @@ class ControlNet(ControlBase):
|
|
|
|
|
else: |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
dtype = self.control_model.dtype |
|
|
|
|
if comfy.model_management.supports_dtype(self.device, dtype): |
|
|
|
|
precision_scope = lambda a: contextlib.nullcontext(a) |
|
|
|
|
else: |
|
|
|
|
precision_scope = torch.autocast |
|
|
|
|
dtype = torch.float32 |
|
|
|
|
|
|
|
|
|
output_dtype = x_noisy.dtype |
|
|
|
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: |
|
|
|
|
if self.cond_hint is not None: |
|
|
|
|
del self.cond_hint |
|
|
|
|
self.cond_hint = None |
|
|
|
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) |
|
|
|
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) |
|
|
|
|
if x_noisy.shape[0] != self.cond_hint.shape[0]: |
|
|
|
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
context = cond['c_crossattn'] |
|
|
|
|
y = cond.get('y', None) |
|
|
|
|
if y is not None: |
|
|
|
|
y = y.to(self.control_model.dtype) |
|
|
|
|
y = y.to(dtype) |
|
|
|
|
timestep = self.model_sampling_current.timestep(t) |
|
|
|
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) |
|
|
|
|
|
|
|
|
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) |
|
|
|
|
with precision_scope(comfy.model_management.get_autocast_device(self.device)): |
|
|
|
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) |
|
|
|
|
return self.control_merge(None, control, control_prev, output_dtype) |
|
|
|
|
|
|
|
|
|
def copy(self): |
|
|
|
|