|
|
|
@ -334,8 +334,13 @@ class ControlNet:
|
|
|
|
|
self.cond_hint = None |
|
|
|
|
self.strength = 1.0 |
|
|
|
|
self.device = device |
|
|
|
|
self.previous_controlnet = None |
|
|
|
|
|
|
|
|
|
def get_control(self, x_noisy, t, cond_txt): |
|
|
|
|
control_prev = None |
|
|
|
|
if self.previous_controlnet is not None: |
|
|
|
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) |
|
|
|
|
|
|
|
|
|
output_dtype = x_noisy.dtype |
|
|
|
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: |
|
|
|
|
if self.cond_hint is not None: |
|
|
|
@ -354,10 +359,15 @@ class ControlNet:
|
|
|
|
|
self.control_model = model_management.unload_if_low_vram(self.control_model) |
|
|
|
|
out = [] |
|
|
|
|
autocast_enabled = torch.is_autocast_enabled() |
|
|
|
|
for x in control: |
|
|
|
|
|
|
|
|
|
for i in range(len(control)): |
|
|
|
|
x = control[i] |
|
|
|
|
x *= self.strength |
|
|
|
|
if x.dtype != output_dtype and not autocast_enabled: |
|
|
|
|
x = x.to(output_dtype) |
|
|
|
|
|
|
|
|
|
if control_prev is not None: |
|
|
|
|
x += control_prev[i] |
|
|
|
|
out.append(x) |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
@ -366,7 +376,13 @@ class ControlNet:
|
|
|
|
|
self.strength = strength |
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
def set_previous_controlnet(self, controlnet): |
|
|
|
|
self.previous_controlnet = controlnet |
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
def cleanup(self): |
|
|
|
|
if self.previous_controlnet is not None: |
|
|
|
|
self.previous_controlnet.cleanup() |
|
|
|
|
if self.cond_hint is not None: |
|
|
|
|
del self.cond_hint |
|
|
|
|
self.cond_hint = None |
|
|
|
@ -377,6 +393,13 @@ class ControlNet:
|
|
|
|
|
c.strength = self.strength |
|
|
|
|
return c |
|
|
|
|
|
|
|
|
|
def get_control_models(self): |
|
|
|
|
out = [] |
|
|
|
|
if self.previous_controlnet is not None: |
|
|
|
|
out += self.previous_controlnet.get_control_models() |
|
|
|
|
out.append(self.control_model) |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def load_controlnet(ckpt_path): |
|
|
|
|
controlnet_data = load_torch_file(ckpt_path) |
|
|
|
|
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' |
|
|
|
|