|
|
@ -359,6 +359,28 @@ class VAE: |
|
|
|
samples = samples.cpu() |
|
|
|
samples = samples.cpu() |
|
|
|
return samples |
|
|
|
return samples |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resize_image_to(tensor, target_latent_tensor, batched_number): |
|
|
|
|
|
|
|
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") |
|
|
|
|
|
|
|
target_batch_size = target_latent_tensor.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
current_batch_size = tensor.shape[0] |
|
|
|
|
|
|
|
print(current_batch_size, target_batch_size) |
|
|
|
|
|
|
|
if current_batch_size == 1: |
|
|
|
|
|
|
|
return tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
per_batch = target_batch_size // batched_number |
|
|
|
|
|
|
|
tensor = tensor[:per_batch] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if per_batch > tensor.shape[0]: |
|
|
|
|
|
|
|
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
current_batch_size = tensor.shape[0] |
|
|
|
|
|
|
|
if current_batch_size == target_batch_size: |
|
|
|
|
|
|
|
return tensor |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
return torch.cat([tensor] * batched_number, dim=0) |
|
|
|
|
|
|
|
|
|
|
|
class ControlNet: |
|
|
|
class ControlNet: |
|
|
|
def __init__(self, control_model, device="cuda"): |
|
|
|
def __init__(self, control_model, device="cuda"): |
|
|
|
self.control_model = control_model |
|
|
|
self.control_model = control_model |
|
|
@ -368,7 +390,7 @@ class ControlNet: |
|
|
|
self.device = device |
|
|
|
self.device = device |
|
|
|
self.previous_controlnet = None |
|
|
|
self.previous_controlnet = None |
|
|
|
|
|
|
|
|
|
|
|
def get_control(self, x_noisy, t, cond_txt): |
|
|
|
def get_control(self, x_noisy, t, cond_txt, batched_number): |
|
|
|
control_prev = None |
|
|
|
control_prev = None |
|
|
|
if self.previous_controlnet is not None: |
|
|
|
if self.previous_controlnet is not None: |
|
|
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) |
|
|
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) |
|
|
@ -378,7 +400,7 @@ class ControlNet: |
|
|
|
if self.cond_hint is not None: |
|
|
|
if self.cond_hint is not None: |
|
|
|
del self.cond_hint |
|
|
|
del self.cond_hint |
|
|
|
self.cond_hint = None |
|
|
|
self.cond_hint = None |
|
|
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) |
|
|
|
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
if self.control_model.dtype == torch.float16: |
|
|
|
if self.control_model.dtype == torch.float16: |
|
|
|
precision_scope = torch.autocast |
|
|
|
precision_scope = torch.autocast |
|
|
@ -516,7 +538,7 @@ class T2IAdapter: |
|
|
|
self.cond_hint_original = None |
|
|
|
self.cond_hint_original = None |
|
|
|
self.cond_hint = None |
|
|
|
self.cond_hint = None |
|
|
|
|
|
|
|
|
|
|
|
def get_control(self, x_noisy, t, cond_txt): |
|
|
|
def get_control(self, x_noisy, t, cond_txt, batched_number): |
|
|
|
control_prev = None |
|
|
|
control_prev = None |
|
|
|
if self.previous_controlnet is not None: |
|
|
|
if self.previous_controlnet is not None: |
|
|
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) |
|
|
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) |
|
|
@ -525,7 +547,7 @@ class T2IAdapter: |
|
|
|
if self.cond_hint is not None: |
|
|
|
if self.cond_hint is not None: |
|
|
|
del self.cond_hint |
|
|
|
del self.cond_hint |
|
|
|
self.cond_hint = None |
|
|
|
self.cond_hint = None |
|
|
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) |
|
|
|
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) |
|
|
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1: |
|
|
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1: |
|
|
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) |
|
|
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) |
|
|
|
self.t2i_model.to(self.device) |
|
|
|
self.t2i_model.to(self.device) |
|
|
|