|
|
|
@ -6,6 +6,10 @@ import contextlib
|
|
|
|
|
from comfy import model_management |
|
|
|
|
from .ldm.models.diffusion.ddim import DDIMSampler |
|
|
|
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps |
|
|
|
|
import math |
|
|
|
|
|
|
|
|
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) |
|
|
|
|
return abs(a*b) // math.gcd(a, b) |
|
|
|
|
|
|
|
|
|
#The main sampling function shared by all the samplers |
|
|
|
|
#Returns predicted noise |
|
|
|
@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|
|
|
|
if c1.keys() != c2.keys(): |
|
|
|
|
return False |
|
|
|
|
if 'c_crossattn' in c1: |
|
|
|
|
if c1['c_crossattn'].shape != c2['c_crossattn'].shape: |
|
|
|
|
return False |
|
|
|
|
s1 = c1['c_crossattn'].shape |
|
|
|
|
s2 = c2['c_crossattn'].shape |
|
|
|
|
if s1 != s2: |
|
|
|
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
mult_min = lcm(s1[1], s2[1]) |
|
|
|
|
diff = mult_min // min(s1[1], s2[1]) |
|
|
|
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much |
|
|
|
|
return False |
|
|
|
|
if 'c_concat' in c1: |
|
|
|
|
if c1['c_concat'].shape != c2['c_concat'].shape: |
|
|
|
|
return False |
|
|
|
@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|
|
|
|
c_crossattn = [] |
|
|
|
|
c_concat = [] |
|
|
|
|
c_adm = [] |
|
|
|
|
crossattn_max_len = 0 |
|
|
|
|
for x in c_list: |
|
|
|
|
if 'c_crossattn' in x: |
|
|
|
|
c_crossattn.append(x['c_crossattn']) |
|
|
|
|
c = x['c_crossattn'] |
|
|
|
|
if crossattn_max_len == 0: |
|
|
|
|
crossattn_max_len = c.shape[1] |
|
|
|
|
else: |
|
|
|
|
crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) |
|
|
|
|
c_crossattn.append(c) |
|
|
|
|
if 'c_concat' in x: |
|
|
|
|
c_concat.append(x['c_concat']) |
|
|
|
|
if 'c_adm' in x: |
|
|
|
|
c_adm.append(x['c_adm']) |
|
|
|
|
out = {} |
|
|
|
|
if len(c_crossattn) > 0: |
|
|
|
|
out['c_crossattn'] = [torch.cat(c_crossattn)] |
|
|
|
|
c_crossattn_out = [] |
|
|
|
|
for c in c_crossattn: |
|
|
|
|
if c.shape[1] < crossattn_max_len: |
|
|
|
|
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result |
|
|
|
|
c_crossattn_out.append(c) |
|
|
|
|
|
|
|
|
|
if len(c_crossattn_out) > 0: |
|
|
|
|
out['c_crossattn'] = [torch.cat(c_crossattn_out)] |
|
|
|
|
if len(c_concat) > 0: |
|
|
|
|
out['c_concat'] = [torch.cat(c_concat)] |
|
|
|
|
if len(c_adm) > 0: |
|
|
|
|