You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
2.0 KiB
65 lines
2.0 KiB
1 year ago
|
import enum
|
||
|
import torch
|
||
|
import math
|
||
|
import comfy.utils
|
||
|
|
||
|
|
||
|
def lcm(a, b):  # TODO: eventually replace by math.lcm (added in python3.9)
    """Return the least common multiple of the integers ``a`` and ``b``.

    The result is always non-negative. If either argument is zero the result
    is 0, which mirrors ``math.lcm`` and avoids dividing by
    ``math.gcd(0, 0) == 0``.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
|
||
|
|
||
|
class CONDRegular:
    """Holds one conditioning value in ``self.cond`` and knows how to batch it.

    ``cond`` is used as a tensor-like object: the methods below read its
    ``shape``, call ``.to(device)`` on it and pass it to ``torch.cat``.
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Build a new instance of the same class as ``self`` around ``cond``."""
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy whose value was batched and moved to ``device``.

        The stored value is passed through
        ``comfy.utils.repeat_to_batch_size`` with ``batch_size`` and the
        result is moved with ``.to(device)``. Extra keyword arguments are
        accepted and ignored.
        """
        batched = comfy.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Tell whether ``other`` holds a value with exactly the same shape."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate this value with those of ``others`` using ``torch.cat``.

        This object's value comes first, followed by the others in order.
        """
        return torch.cat([self.cond, *(item.cond for item in others)])
|
||
|
|
||
|
class CONDNoiseShape(CONDRegular):
    """Conditioning value that is cropped to an ``area`` before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Return a copy holding the cropped, batched value on ``device``.

        The first two dimensions of the stored value are kept whole. The
        third dimension is cut to ``area[0]`` entries starting at ``area[2]``
        and the fourth to ``area[1]`` entries starting at ``area[3]``. The
        crop is then passed through ``comfy.utils.repeat_to_batch_size`` and
        moved with ``.to(device)``. Extra keyword arguments are ignored.
        """
        dim2_window = slice(area[2], area[0] + area[2])
        dim3_window = slice(area[3], area[1] + area[3])
        cropped = self.cond[:, :, dim2_window, dim3_window]
        batched = comfy.utils.repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched.to(device))
|
||
|
|
||
|
|
||
|
class CONDCrossAttn(CONDRegular):
    """Conditioning value whose second dimension may differ between items.

    Items with different lengths along dimension 1 can still be concatenated:
    each one is repeated along that dimension up to the least common multiple
    of all lengths.
    """

    def can_concat(self, other):
        """Tell whether ``other`` can be concatenated with this value.

        Identical shapes always can. Otherwise dimensions 0 and 2 must match
        and the common length must be at most 4 times the shorter of the two
        dimension-1 lengths.
        """
        mine = self.cond.shape
        theirs = other.cond.shape
        if mine == theirs:
            return True

        # these 2 cases should not happen
        if not (mine[0] == theirs[0] and mine[2] == theirs[2]):
            return False

        common_len = lcm(mine[1], theirs[1])
        repeat_factor = common_len // min(mine[1], theirs[1])
        # arbitrary limit on the padding because it's probably going to
        # impact performance negatively if it's too much
        return repeat_factor <= 4

    def concat(self, others):
        """Concatenate this value with those of ``others`` using ``torch.cat``.

        Every value shorter than the least common multiple of all
        dimension-1 lengths is first repeated along dimension 1 until it
        reaches that length.
        """
        tensors = [self.cond, *(item.cond for item in others)]

        target_len = self.cond.shape[1]
        for tensor in tensors[1:]:
            target_len = lcm(target_len, tensor.shape[1])

        # padding with repeat doesn't change result
        padded = [
            tensor
            if tensor.shape[1] >= target_len
            else tensor.repeat(1, target_len // tensor.shape[1], 1)
            for tensor in tensors
        ]
        return torch.cat(padded)
|