|
|
|
@ -460,42 +460,18 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
|
|
|
|
uncond[temp[1]] = [o[0], n] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_adm(conds, batch_size, device, noise_augmentor=None): |
|
|
|
|
def encode_adm(model, conds, batch_size, device): |
|
|
|
|
for t in range(len(conds)): |
|
|
|
|
x = conds[t] |
|
|
|
|
adm_out = None |
|
|
|
|
if noise_augmentor is not None: |
|
|
|
|
if 'adm' in x[1]: |
|
|
|
|
adm_inputs = [] |
|
|
|
|
weights = [] |
|
|
|
|
noise_aug = [] |
|
|
|
|
adm_in = x[1]["adm"] |
|
|
|
|
for adm_c in adm_in: |
|
|
|
|
adm_cond = adm_c[0].image_embeds |
|
|
|
|
weight = adm_c[1] |
|
|
|
|
noise_augment = adm_c[2] |
|
|
|
|
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) |
|
|
|
|
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) |
|
|
|
|
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight |
|
|
|
|
weights.append(weight) |
|
|
|
|
noise_aug.append(noise_augment) |
|
|
|
|
adm_inputs.append(adm_out) |
|
|
|
|
|
|
|
|
|
if len(noise_aug) > 1: |
|
|
|
|
adm_out = torch.stack(adm_inputs).sum(0) |
|
|
|
|
#TODO: add a way to control this |
|
|
|
|
noise_augment = 0.05 |
|
|
|
|
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) |
|
|
|
|
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) |
|
|
|
|
adm_out = torch.cat((c_adm, noise_level_emb), 1) |
|
|
|
|
else: |
|
|
|
|
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) |
|
|
|
|
if 'adm' in x[1]: |
|
|
|
|
adm_out = x[1]["adm"] |
|
|
|
|
else: |
|
|
|
|
if 'adm' in x[1]: |
|
|
|
|
adm_out = x[1]["adm"].to(device) |
|
|
|
|
params = x[1].copy() |
|
|
|
|
adm_out = model.encode_adm(device=device, **params) |
|
|
|
|
if adm_out is not None: |
|
|
|
|
x[1] = x[1].copy() |
|
|
|
|
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) |
|
|
|
|
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) |
|
|
|
|
|
|
|
|
|
return conds |
|
|
|
|
|
|
|
|
@ -603,11 +579,8 @@ class KSampler:
|
|
|
|
|
precision_scope = contextlib.nullcontext |
|
|
|
|
|
|
|
|
|
if self.model.is_adm(): |
|
|
|
|
noise_augmentor = None |
|
|
|
|
if hasattr(self.model, 'noise_augmentor'): #unclip |
|
|
|
|
noise_augmentor = self.model.noise_augmentor |
|
|
|
|
positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor) |
|
|
|
|
negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor) |
|
|
|
|
positive = encode_adm(self.model, positive, noise.shape[0], self.device) |
|
|
|
|
negative = encode_adm(self.model, negative, noise.shape[0], self.device) |
|
|
|
|
|
|
|
|
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} |
|
|
|
|
|
|
|
|
|