Browse Source

Improve code legibility.

pull/1943/merge
comfyanonymous 11 months ago
parent
commit
329c571993
  1. 42
      comfy/samplers.py

42
comfy/samplers.py

@ -2,6 +2,7 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import torch import torch
import enum import enum
import collections
from comfy import model_management from comfy import model_management
import math import math
from comfy import model_base from comfy import model_base
@ -61,9 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
for c in model_conds: for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None control = conds.get('control', None)
if 'control' in conds:
control = conds['control']
patches = None patches = None
if 'gligen' in conds: if 'gligen' in conds:
@ -78,7 +77,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
patches['middle_patch'] = [gligen_patch] patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditioning, area, control, patches) cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2): def cond_equal_size(c1, c2):
if c1 is c2: if c1 is c2:
@ -91,24 +91,24 @@ def cond_equal_size(c1, c2):
return True return True
def can_concat_cond(c1, c2): def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape: if c1.input_x.shape != c2.input_x.shape:
return False return False
#control def objects_concatable(obj1, obj2):
if (c1[4] is None) != (c2[4] is None): if (obj1 is None) != (obj2 is None):
return False return False
if c1[4] is not None: if obj1 is not None:
if c1[4] is not c2[4]: if obj1 is not obj2:
return False return False
return True
#patches if not objects_concatable(c1.control, c2.control):
if (c1[5] is None) != (c2[5] is None):
return False return False
if (c1[5] is not None):
if c1[5] is not c2[5]: if not objects_concatable(c1.patches, c2.patches):
return False return False
return cond_equal_size(c1[2], c2[2]) return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list): def cond_cat(c_list):
c_crossattn = [] c_crossattn = []
@ -184,13 +184,13 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
for x in to_batch: for x in to_batch:
o = to_run.pop(x) o = to_run.pop(x)
p = o[0] p = o[0]
input_x += [p[0]] input_x.append(p.input_x)
mult += [p[1]] mult.append(p.mult)
c += [p[2]] c.append(p.conditioning)
area += [p[3]] area.append(p.area)
cond_or_uncond += [o[1]] cond_or_uncond.append(o[1])
control = p[4] control = p.control
patches = p[5] patches = p.patches
batch_chunks = len(cond_or_uncond) batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x) input_x = torch.cat(input_x)

Loading…
Cancel
Save