diff --git a/comfy/samplers.py b/comfy/samplers.py index 0b38fbd1..f88b790d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,9 +9,58 @@ import math from comfy import model_base import comfy.utils + def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) return abs(a*b) // math.gcd(a, b) +class CONDRegular: + def __init__(self, cond): + self.cond = cond + + def can_concat(self, other): + if self.cond.shape != other.cond.shape: + return False + return True + + def concat(self, others): + conds = [self.cond] + for x in others: + conds.append(x.cond) + return torch.cat(conds) + +class CONDCrossAttn: + def __init__(self, cond): + self.cond = cond + + def can_concat(self, other): + s1 = self.cond.shape + s2 = other.cond.shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False + return True + + def concat(self, others): + conds = [self.cond] + crossattn_max_len = self.cond.shape[1] + for x in others: + c = x.cond + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + conds.append(c) + + out = [] + for c in conds: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + out.append(c) + return torch.cat(out) + + #The main sampling function shared by all the samplers #Returns predicted noise def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): @@ -67,7 +116,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} - conditionning['c_crossattn'] = cond[0] + conditionning['c_crossattn'] = CONDCrossAttn(cond[0]) if 'concat' in cond[1]: cond_concat_in = cond[1]['concat'] @@ -76,10 +125,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod for x in cond_concat_in: cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) + conditionning['c_concat'] = CONDRegular(torch.cat(cropped, dim=1)) if adm_cond is not None: - conditionning['c_adm'] = adm_cond + conditionning['c_adm'] = CONDRegular(adm_cond) control = None if 'control' in cond[1]: @@ -105,22 +154,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod return True if c1.keys() != c2.keys(): return False - if 'c_crossattn' in c1: - s1 = c1['c_crossattn'].shape - s2 = c2['c_crossattn'].shape - if s1 != s2: - if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen - return False - - mult_min = lcm(s1[1], s2[1]) - diff = mult_min // min(s1[1], s2[1]) - if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much - return False - if 'c_concat' in c1: - if c1['c_concat'].shape != c2['c_concat'].shape: - return False - if 'c_adm' in c1: - if c1['c_adm'].shape != c2['c_adm'].shape: + for k in c1: + if not c1[k].can_concat(c2[k]): return False return True @@ -149,31 +184,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod c_concat = [] c_adm = [] crossattn_max_len = 0 + + temp = {} for x in c_list: - if 'c_crossattn' in x: - c = x['c_crossattn'] - if crossattn_max_len == 0: - crossattn_max_len = c.shape[1] - else: - crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) - c_crossattn.append(c) - if 'c_concat' in x: - c_concat.append(x['c_concat']) - if 'c_adm' in x: - c_adm.append(x['c_adm']) + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + out = {} - c_crossattn_out = [] - for c in c_crossattn: - if c.shape[1] < crossattn_max_len: - c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result - c_crossattn_out.append(c) - - if len(c_crossattn_out) > 0: - out['c_crossattn'] = torch.cat(c_crossattn_out) - if len(c_concat) > 0: - out['c_concat'] = torch.cat(c_concat) - if len(c_adm) > 0: - out['c_adm'] = torch.cat(c_adm) + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + return out def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):