Browse Source

Sampling code refactor to make it easier to add more conds.

pull/1836/head
comfyanonymous 1 year ago
parent
commit
3fce8881ca
  1. 107
      comfy/samplers.py

107
comfy/samplers.py

@ -9,9 +9,58 @@ import math
from comfy import model_base
import comfy.utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular:
def __init__(self, cond):
self.cond = cond
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
return True
def concat(self, others):
conds = [self.cond]
for x in others:
conds.append(x.cond)
return torch.cat(conds)
class CONDCrossAttn:
def __init__(self, cond):
self.cond = cond
def can_concat(self, other):
s1 = self.cond.shape
s2 = other.cond.shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
return True
def concat(self, others):
conds = [self.cond]
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []
for c in conds:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
@ -67,7 +116,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditionning['c_crossattn'] = cond[0]
conditionning['c_crossattn'] = CONDCrossAttn(cond[0])
if 'concat' in cond[1]:
cond_concat_in = cond[1]['concat']
@ -76,10 +125,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
conditionning['c_concat'] = CONDRegular(torch.cat(cropped, dim=1))
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
conditionning['c_adm'] = CONDRegular(adm_cond)
control = None
if 'control' in cond[1]:
@ -105,22 +154,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return True
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
s1 = c1['c_crossattn'].shape
s2 = c2['c_crossattn'].shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
@ -149,31 +184,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
if 'c_crossattn' in x:
c = x['c_crossattn']
if crossattn_max_len == 0:
crossattn_max_len = c.shape[1]
else:
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
c_crossattn.append(c)
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
c_crossattn_out = []
for c in c_crossattn:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
c_crossattn_out.append(c)
if len(c_crossattn_out) > 0:
out['c_crossattn'] = torch.cat(c_crossattn_out)
if len(c_concat) > 0:
out['c_concat'] = torch.cat(c_concat)
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):

Loading…
Cancel
Save