You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.4 KiB
78 lines
2.4 KiB
import torch |
|
import math |
|
import comfy.utils |
|
|
|
|
|
def lcm(a, b):
    """Return the least common multiple of integers ``a`` and ``b``.

    The result is non-negative and is 0 when either argument is 0, matching
    ``math.lcm`` (Python 3.9+).
    """
    # TODO: eventually replace by math.lcm (added in python3.9)
    # math.gcd(0, 0) == 0, so without this guard lcm(0, 0) would divide by zero.
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
|
|
|
class CONDRegular:
    """Conditioning wrapper around a single tensor stored in ``self.cond``.

    Subclasses customise how the tensor is prepared for a batch
    (``process_cond``) and how several instances are merged
    (``can_concat`` / ``concat``).
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Build a new instance of the same concrete class wrapping ``cond``."""
        cls = type(self)
        return cls(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy prepared for a batch.

        The tensor is passed through ``comfy.utils.repeat_to_batch_size`` with
        ``batch_size`` and then moved to ``device``. Extra keyword arguments are
        accepted and ignored.
        """
        batched = comfy.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Return True only when ``other.cond`` has exactly the same shape."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Join this tensor and every ``other.cond`` with ``torch.cat`` (dim 0)."""
        return torch.cat([self.cond, *(item.cond for item in others)])
|
|
|
class CONDNoiseShape(CONDRegular):
    """Conditioning tensor that is cropped to an ``area`` before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop dims 2 and 3 of the tensor to ``area``, then batch and move it.

        ``area`` is indexed as (size_d2, size_d3, start_d2, start_d3). The crop
        keeps ``[start_d2, start_d2 + size_d2)`` on dim 2 and
        ``[start_d3, start_d3 + size_d3)`` on dim 3.
        NOTE(review): presumably this is (height, width, y, x); confirm
        against callers.
        """
        size_d2, size_d3 = area[0], area[1]
        start_d2, start_d3 = area[2], area[3]
        cropped = self.cond[:, :, start_d2:start_d2 + size_d2, start_d3:start_d3 + size_d3]
        batched = comfy.utils.repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched.to(device))
|
|
|
|
|
class CONDCrossAttn(CONDRegular):
    """Conditioning tensor whose dim-1 length may differ between entries.

    Entries with different dim-1 lengths are made compatible by repeating each
    tensor along dim 1 up to the least common multiple of the lengths.
    """

    def can_concat(self, other):
        """Return whether ``other`` can be merged with this conditioning.

        Identical shapes always qualify. Otherwise dims 0 and 2 must match, and
        the LCM of the two dim-1 lengths may be at most 4x the shorter length.
        """
        mine = self.cond.shape
        theirs = other.cond.shape
        if mine == theirs:
            return True
        # these 2 cases should not happen
        if mine[0] != theirs[0] or mine[2] != theirs[2]:
            return False
        common = lcm(mine[1], theirs[1])
        repeat_factor = common // min(mine[1], theirs[1])
        # arbitrary limit on the padding because it's probably going to impact
        # performance negatively if it's too much
        return repeat_factor <= 4

    def concat(self, others):
        """Tile every tensor along dim 1 to a common length, then ``torch.cat`` on dim 0.

        The common length is the LCM of all dim-1 lengths.
        """
        conds = [self.cond]
        conds.extend(entry.cond for entry in others)

        target_len = conds[0].shape[1]
        for tensor in conds[1:]:
            target_len = lcm(target_len, tensor.shape[1])

        tiled = []
        for tensor in conds:
            length = tensor.shape[1]
            if length >= target_len:
                tiled.append(tensor)
            else:
                # padding with repeat doesn't change result (presumably because
                # attention over duplicated tokens yields the same output)
                tiled.append(tensor.repeat(1, target_len // length, 1))
        return torch.cat(tiled)
|
|
|
class CONDConstant(CONDRegular):
    """Conditioning value that is passed through without batching.

    ``process_cond`` ignores the batch size and device. Two instances are
    mergeable only when their values compare equal.
    """

    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        """Return a new wrapper around the very same value; nothing is moved or repeated."""
        return self._copy_with(self.cond)

    def can_concat(self, other):
        """Return True unless ``self.cond != other.cond`` is truthy."""
        return not (self.cond != other.cond)

    def concat(self, others):
        """Return this instance's value as-is; ``others`` is not inspected."""
        return self.cond
|
|
|