|
|
@ -2,6 +2,7 @@ from .k_diffusion import sampling as k_diffusion_sampling |
|
|
|
from .extra_samplers import uni_pc |
|
|
|
from .extra_samplers import uni_pc |
|
|
|
import torch |
|
|
|
import torch |
|
|
|
import enum |
|
|
|
import enum |
|
|
|
|
|
|
|
import collections |
|
|
|
from comfy import model_management |
|
|
|
from comfy import model_management |
|
|
|
import math |
|
|
|
import math |
|
|
|
from comfy import model_base |
|
|
|
from comfy import model_base |
|
|
@ -61,9 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in): |
|
|
|
for c in model_conds: |
|
|
|
for c in model_conds: |
|
|
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) |
|
|
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) |
|
|
|
|
|
|
|
|
|
|
|
control = None |
|
|
|
control = conds.get('control', None) |
|
|
|
if 'control' in conds: |
|
|
|
|
|
|
|
control = conds['control'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
patches = None |
|
|
|
patches = None |
|
|
|
if 'gligen' in conds: |
|
|
|
if 'gligen' in conds: |
|
|
@ -78,7 +77,8 @@ def get_area_and_mult(conds, x_in, timestep_in): |
|
|
|
|
|
|
|
|
|
|
|
patches['middle_patch'] = [gligen_patch] |
|
|
|
patches['middle_patch'] = [gligen_patch] |
|
|
|
|
|
|
|
|
|
|
|
return (input_x, mult, conditioning, area, control, patches) |
|
|
|
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) |
|
|
|
|
|
|
|
return cond_obj(input_x, mult, conditioning, area, control, patches) |
|
|
|
|
|
|
|
|
|
|
|
def cond_equal_size(c1, c2): |
|
|
|
def cond_equal_size(c1, c2): |
|
|
|
if c1 is c2: |
|
|
|
if c1 is c2: |
|
|
@ -91,24 +91,24 @@ def cond_equal_size(c1, c2): |
|
|
|
return True |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def can_concat_cond(c1, c2): |
|
|
|
def can_concat_cond(c1, c2): |
|
|
|
if c1[0].shape != c2[0].shape: |
|
|
|
if c1.input_x.shape != c2.input_x.shape: |
|
|
|
return False |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
#control |
|
|
|
def objects_concatable(obj1, obj2): |
|
|
|
if (c1[4] is None) != (c2[4] is None): |
|
|
|
if (obj1 is None) != (obj2 is None): |
|
|
|
return False |
|
|
|
|
|
|
|
if c1[4] is not None: |
|
|
|
|
|
|
|
if c1[4] is not c2[4]: |
|
|
|
|
|
|
|
return False |
|
|
|
return False |
|
|
|
|
|
|
|
if obj1 is not None: |
|
|
|
|
|
|
|
if obj1 is not obj2: |
|
|
|
|
|
|
|
return False |
|
|
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
#patches |
|
|
|
if not objects_concatable(c1.control, c2.control): |
|
|
|
if (c1[5] is None) != (c2[5] is None): |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not objects_concatable(c1.patches, c2.patches): |
|
|
|
return False |
|
|
|
return False |
|
|
|
if (c1[5] is not None): |
|
|
|
|
|
|
|
if c1[5] is not c2[5]: |
|
|
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return cond_equal_size(c1[2], c2[2]) |
|
|
|
return cond_equal_size(c1.conditioning, c2.conditioning) |
|
|
|
|
|
|
|
|
|
|
|
def cond_cat(c_list): |
|
|
|
def cond_cat(c_list): |
|
|
|
c_crossattn = [] |
|
|
|
c_crossattn = [] |
|
|
@ -184,13 +184,13 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): |
|
|
|
for x in to_batch: |
|
|
|
for x in to_batch: |
|
|
|
o = to_run.pop(x) |
|
|
|
o = to_run.pop(x) |
|
|
|
p = o[0] |
|
|
|
p = o[0] |
|
|
|
input_x += [p[0]] |
|
|
|
input_x.append(p.input_x) |
|
|
|
mult += [p[1]] |
|
|
|
mult.append(p.mult) |
|
|
|
c += [p[2]] |
|
|
|
c.append(p.conditioning) |
|
|
|
area += [p[3]] |
|
|
|
area.append(p.area) |
|
|
|
cond_or_uncond += [o[1]] |
|
|
|
cond_or_uncond.append(o[1]) |
|
|
|
control = p[4] |
|
|
|
control = p.control |
|
|
|
patches = p[5] |
|
|
|
patches = p.patches |
|
|
|
|
|
|
|
|
|
|
|
batch_chunks = len(cond_or_uncond) |
|
|
|
batch_chunks = len(cond_or_uncond) |
|
|
|
input_x = torch.cat(input_x) |
|
|
|
input_x = torch.cat(input_x) |
|
|
|