|
|
|
@ -9,9 +9,58 @@ import math
|
|
|
|
|
from comfy import model_base |
|
|
|
|
import comfy.utils |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) |
|
|
|
|
return abs(a*b) // math.gcd(a, b) |
|
|
|
|
|
|
|
|
|
class CONDRegular: |
|
|
|
|
def __init__(self, cond): |
|
|
|
|
self.cond = cond |
|
|
|
|
|
|
|
|
|
def can_concat(self, other): |
|
|
|
|
if self.cond.shape != other.cond.shape: |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
def concat(self, others): |
|
|
|
|
conds = [self.cond] |
|
|
|
|
for x in others: |
|
|
|
|
conds.append(x.cond) |
|
|
|
|
return torch.cat(conds) |
|
|
|
|
|
|
|
|
|
class CONDCrossAttn: |
|
|
|
|
def __init__(self, cond): |
|
|
|
|
self.cond = cond |
|
|
|
|
|
|
|
|
|
def can_concat(self, other): |
|
|
|
|
s1 = self.cond.shape |
|
|
|
|
s2 = other.cond.shape |
|
|
|
|
if s1 != s2: |
|
|
|
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
mult_min = lcm(s1[1], s2[1]) |
|
|
|
|
diff = mult_min // min(s1[1], s2[1]) |
|
|
|
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
def concat(self, others): |
|
|
|
|
conds = [self.cond] |
|
|
|
|
crossattn_max_len = self.cond.shape[1] |
|
|
|
|
for x in others: |
|
|
|
|
c = x.cond |
|
|
|
|
crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) |
|
|
|
|
conds.append(c) |
|
|
|
|
|
|
|
|
|
out = [] |
|
|
|
|
for c in conds: |
|
|
|
|
if c.shape[1] < crossattn_max_len: |
|
|
|
|
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result |
|
|
|
|
out.append(c) |
|
|
|
|
return torch.cat(out) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#The main sampling function shared by all the samplers |
|
|
|
|
#Returns predicted noise |
|
|
|
|
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): |
|
|
|
@ -67,7 +116,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|
|
|
|
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) |
|
|
|
|
|
|
|
|
|
conditionning = {} |
|
|
|
|
conditionning['c_crossattn'] = cond[0] |
|
|
|
|
conditionning['c_crossattn'] = CONDCrossAttn(cond[0]) |
|
|
|
|
|
|
|
|
|
if 'concat' in cond[1]: |
|
|
|
|
cond_concat_in = cond[1]['concat'] |
|
|
|
@ -76,10 +125,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|
|
|
|
for x in cond_concat_in: |
|
|
|
|
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] |
|
|
|
|
cropped.append(cr) |
|
|
|
|
conditionning['c_concat'] = torch.cat(cropped, dim=1) |
|
|
|
|
conditionning['c_concat'] = CONDRegular(torch.cat(cropped, dim=1)) |
|
|
|
|
|
|
|
|
|
if adm_cond is not None: |
|
|
|
|
conditionning['c_adm'] = adm_cond |
|
|
|
|
conditionning['c_adm'] = CONDRegular(adm_cond) |
|
|
|
|
|
|
|
|
|
control = None |
|
|
|
|
if 'control' in cond[1]: |
|
|
|
@ -105,22 +154,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|
|
|
|
return True |
|
|
|
|
if c1.keys() != c2.keys(): |
|
|
|
|
return False |
|
|
|
|
if 'c_crossattn' in c1: |
|
|
|
|
s1 = c1['c_crossattn'].shape |
|
|
|
|
s2 = c2['c_crossattn'].shape |
|
|
|
|
if s1 != s2: |
|
|
|
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
mult_min = lcm(s1[1], s2[1]) |
|
|
|
|
diff = mult_min // min(s1[1], s2[1]) |
|
|
|
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much |
|
|
|
|
return False |
|
|
|
|
if 'c_concat' in c1: |
|
|
|
|
if c1['c_concat'].shape != c2['c_concat'].shape: |
|
|
|
|
return False |
|
|
|
|
if 'c_adm' in c1: |
|
|
|
|
if c1['c_adm'].shape != c2['c_adm'].shape: |
|
|
|
|
for k in c1: |
|
|
|
|
if not c1[k].can_concat(c2[k]): |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
@ -149,31 +184,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|
|
|
|
c_concat = [] |
|
|
|
|
c_adm = [] |
|
|
|
|
crossattn_max_len = 0 |
|
|
|
|
|
|
|
|
|
temp = {} |
|
|
|
|
for x in c_list: |
|
|
|
|
if 'c_crossattn' in x: |
|
|
|
|
c = x['c_crossattn'] |
|
|
|
|
if crossattn_max_len == 0: |
|
|
|
|
crossattn_max_len = c.shape[1] |
|
|
|
|
else: |
|
|
|
|
crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) |
|
|
|
|
c_crossattn.append(c) |
|
|
|
|
if 'c_concat' in x: |
|
|
|
|
c_concat.append(x['c_concat']) |
|
|
|
|
if 'c_adm' in x: |
|
|
|
|
c_adm.append(x['c_adm']) |
|
|
|
|
for k in x: |
|
|
|
|
cur = temp.get(k, []) |
|
|
|
|
cur.append(x[k]) |
|
|
|
|
temp[k] = cur |
|
|
|
|
|
|
|
|
|
out = {} |
|
|
|
|
c_crossattn_out = [] |
|
|
|
|
for c in c_crossattn: |
|
|
|
|
if c.shape[1] < crossattn_max_len: |
|
|
|
|
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result |
|
|
|
|
c_crossattn_out.append(c) |
|
|
|
|
|
|
|
|
|
if len(c_crossattn_out) > 0: |
|
|
|
|
out['c_crossattn'] = torch.cat(c_crossattn_out) |
|
|
|
|
if len(c_concat) > 0: |
|
|
|
|
out['c_concat'] = torch.cat(c_concat) |
|
|
|
|
if len(c_adm) > 0: |
|
|
|
|
out['c_adm'] = torch.cat(c_adm) |
|
|
|
|
for k in temp: |
|
|
|
|
conds = temp[k] |
|
|
|
|
out[k] = conds[0].concat(conds[1:]) |
|
|
|
|
|
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): |
|
|
|
|