Browse Source

Add support for GLIGEN textbox model.

pull/543/head
comfyanonymous 2 years ago
parent
commit
3696d1699a
  1. 343
      comfy/gligen.py
  2. 16
      comfy/ldm/modules/attention.py
  3. 2
      comfy/ldm/modules/diffusionmodules/openaimodel.py
  4. 6
      comfy/model_management.py
  5. 57
      comfy/samplers.py
  6. 22
      comfy/sd.py
  7. 2
      folder_paths.py
  8. 0
      models/gligen/put_gligen_models_here
  9. 71
      nodes.py

343
comfy/gligen.py

@ -0,0 +1,343 @@
import torch
from torch import nn, einsum
from ldm.modules.attention import CrossAttention
from inspect import isfunction
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
x = x + self.scale * \
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
N_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense2(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
B, N_visual, _ = x.shape
B, N_ground, _ = objs.shape
objs = self.linear(objs)
# sanity check
size_v = math.sqrt(N_visual)
size_g = math.sqrt(N_ground)
assert int(size_v) == size_v, "Visual tokens must be square rootable"
assert int(size_g) == size_g, "Grounding tokens must be square rootable"
size_v = int(size_v)
size_g = int(size_g)
# select grounding token and resize it to visual token size as residual
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
:, N_visual:, :]
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
out = torch.nn.functional.interpolate(
out, (size_v, size_v), mode='bicubic')
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
# add residual to visual feature
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class FourierEmbedder():
def __init__(self, num_freqs=64, temperature=100):
self.num_freqs = num_freqs
self.temperature = temperature
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
@torch.no_grad()
def __call__(self, x, cat_dim=-1):
"x: arbitrary shape of tensor. dim: cat dim"
out = []
for freq in self.freq_bands:
out.append(torch.sin(freq * x))
out.append(torch.cos(freq * x))
return torch.cat(out, cat_dim)
class PositionNet(nn.Module):
def __init__(self, in_dim, out_dim, fourier_freqs=8):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential(
nn.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(
torch.zeros([self.in_dim]))
self.null_position_feature = torch.nn.Parameter(
torch.zeros([self.position_dim]))
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
objs = self.linears(
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
assert objs.shape == torch.Size([B, N, self.out_dim])
return objs
class Gligen(nn.Module):
def __init__(self, modules, position_net, key_dim):
super().__init__()
self.module_list = nn.ModuleList(modules)
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30
def _set_position(self, boxes, masks, positive_embeddings):
objs = self.position_net(boxes, masks, positive_embeddings)
def func(key, x):
module = self.module_list[key]
return module(x, objs)
return func
def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu")
boxes = []
positive_embeddings = []
for p in position_params:
x1 = (p[4]) / w
y1 = (p[3]) / h
x2 = (p[4] + p[2]) / w
y2 = (p[3] + p[1]) / h
masks[len(boxes)] = 1.0
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
positive_embeddings += [p[0]]
append_boxes = []
append_conds = []
if len(boxes) < self.max_objs:
append_boxes = [torch.zeros(
[self.max_objs - len(boxes), 4], device="cpu")]
append_conds = [torch.zeros(
[self.max_objs - len(boxes), self.key_dim], device="cpu")]
box_out = torch.cat(
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
masks = masks.unsqueeze(0).repeat(batch, 1)
conds = torch.cat(positive_embeddings +
append_conds).unsqueeze(0).repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))
def set_empty(self, latent_image_shape, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
box_out = torch.zeros([self.max_objs, 4],
device="cpu").repeat(batch, 1, 1)
conds = torch.zeros([self.max_objs, self.key_dim],
device="cpu").repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))
def cleanup(self):
pass
def get_models(self):
return [self]
def load_gligen(sd):
sd_k = sd.keys()
output_list = []
key_dim = 768
for a in ["input_blocks", "middle_block", "output_blocks"]:
for b in range(20):
k_temp = filter(lambda k: "{}.{}.".format(a, b)
in k and ".fuser." in k, sd_k)
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
n_sd = {}
for k in k_temp:
n_sd[k[1]] = sd[k[0]]
if len(n_sd) > 0:
query_dim = n_sd["linear.weight"].shape[0]
key_dim = n_sd["linear.weight"].shape[1]
if key_dim == 768: # SD1.x
n_heads = 8
d_head = query_dim // n_heads
else:
d_head = 64
n_heads = query_dim // d_head
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
gated.load_state_dict(n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
in_dim = sd["position_net.null_positive_feature"].shape[0]
out_dim = sd["position_net.linears.4.weight"].shape[0]
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen

16
comfy/ldm/modules/attention.py

@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}): def _forward(self, x, context=None, transformer_options={}):
current_index = None
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}
n = self.norm1(x) n = self.norm1(x)
if "tomesd" in transformer_options: if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
n = self.attn1(n, context=context if self.disable_self_attn else None) n = self.attn1(n, context=context if self.disable_self_attn else None)
x += n x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)
n = self.norm2(x) n = self.norm2(x)
n = self.attn2(n, context=context) n = self.attn2(n, context=context)
x += n x += n
x = self.ff(self.norm3(x)) + x x = self.ff(self.norm3(x)) + x
if current_index is not None:
transformer_options["current_index"] += 1
return x return x

2
comfy/ldm/modules/diffusionmodules/openaimodel.py

@ -782,6 +782,8 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
transformer_options["original_shape"] = list(x.shape) transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"

6
comfy/model_management.py

@ -176,7 +176,7 @@ def load_model_gpu(model):
model_accelerated = True model_accelerated = True
return current_loaded_model return current_loaded_model
def load_controlnet_gpu(models): def load_controlnet_gpu(control_models):
global current_gpu_controlnets global current_gpu_controlnets
global vram_state global vram_state
if vram_state == VRAMState.CPU: if vram_state == VRAMState.CPU:
@ -186,6 +186,10 @@ def load_controlnet_gpu(models):
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return return
models = []
for m in control_models:
models += m.get_models()
for m in current_gpu_controlnets: for m in current_gpu_controlnets:
if m not in models: if m not in models:
m.cpu() m.cpu()

57
comfy/samplers.py

@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
control = None control = None
if 'control' in cond[1]: if 'control' in cond[1]:
control = cond[1]['control'] control = cond[1]['control']
return (input_x, mult, conditionning, area, control)
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2): def cond_equal_size(c1, c2):
if c1 is c2: if c1 is c2:
@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def can_concat_cond(c1, c2): def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape: if c1[0].shape != c2[0].shape:
return False return False
#control
if (c1[4] is None) != (c2[4] is None): if (c1[4] is None) != (c2[4] is None):
return False return False
if c1[4] is not None: if c1[4] is not None:
if c1[4] is not c2[4]: if c1[4] is not c2[4]:
return False return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2]) return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list): def cond_cat(c_list):
@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cond_or_uncond = [] cond_or_uncond = []
area = [] area = []
control = None control = None
patches = None
for x in to_batch: for x in to_batch:
o = to_run.pop(x) o = to_run.pop(x)
p = o[0] p = o[0]
@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
area += [p[3]] area += [p[3]]
cond_or_uncond += [o[1]] cond_or_uncond += [o[1]]
control = p[4] control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond) batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x) input_x = torch.cat(input_x)
@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None: if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options: if 'transformer_options' in model_options:
c['transformer_options'] = model_options['transformer_options'] transformer_options = model_options['transformer_options'].copy()
if patches is not None:
transformer_options["patches"] = patches
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x del input_x
@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy() n = c[1].copy()
conds += [[smallest[0], n]] conds += [[smallest[0], n]]
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
def apply_control_net_to_equal_area(conds, uncond):
cond_cnets = [] cond_cnets = []
cond_other = [] cond_other = []
uncond_cnets = [] uncond_cnets = []
@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond):
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
if 'area' not in x[1]: if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None: if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1]['control']) cond_cnets.append(x[1][name])
else: else:
cond_other.append((x, t)) cond_other.append((x, t))
for t in range(len(uncond)): for t in range(len(uncond)):
x = uncond[t] x = uncond[t]
if 'area' not in x[1]: if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None: if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1]['control']) uncond_cnets.append(x[1][name])
else: else:
uncond_other.append((x, t)) uncond_other.append((x, t))
@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond):
for x in range(len(cond_cnets)): for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)] temp = uncond_other[x % len(uncond_other)]
o = temp[0] o = temp[0]
if 'control' in o[1] and o[1]['control'] is not None: if name in o[1] and o[1][name] is not None:
n = o[1].copy() n = o[1].copy()
n['control'] = cond_cnets[x] n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]] uncond += [[o[0], n]]
else: else:
n = o[1].copy() n = o[1].copy()
n['control'] = cond_cnets[x] n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n] uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device): def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
return conds return conds
class KSampler: class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@ -466,7 +498,8 @@ class KSampler:
for c in negative: for c in negative:
create_cond_with_same_area_if_none(positive, c) create_cond_with_same_area_if_none(positive, c)
apply_control_net_to_equal_area(positive, negative) apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16: if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast precision_scope = torch.autocast

22
comfy/sd.py

@ -13,6 +13,7 @@ from .t2i_adapter import adapter
from . import utils from . import utils
from . import clip_vision from . import clip_vision
from . import gligen
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
@ -378,7 +379,7 @@ class CLIP:
def tokenize(self, text, return_word_ids=False): def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids) return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens): def encode_from_tokens(self, tokens, return_pooled=False):
if self.layer_idx is not None: if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx) self.cond_stage_model.clip_layer(self.layer_idx)
try: try:
@ -388,6 +389,10 @@ class CLIP:
except Exception as e: except Exception as e:
self.patcher.unpatch_model() self.patcher.unpatch_model()
raise e raise e
if return_pooled:
eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
pooled = cond[:, eos_token_index]
return cond, pooled
return cond return cond
def encode(self, text): def encode(self, text):
@ -564,10 +569,10 @@ class ControlNet:
c.strength = self.strength c.strength = self.strength
return c return c
def get_control_models(self): def get_models(self):
out = [] out = []
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models() out += self.previous_controlnet.get_models()
out.append(self.control_model) out.append(self.control_model)
return out return out
@ -737,10 +742,10 @@ class T2IAdapter:
del self.cond_hint del self.cond_hint
self.cond_hint = None self.cond_hint = None
def get_control_models(self): def get_models(self):
out = [] out = []
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models() out += self.previous_controlnet.get_models()
return out return out
def load_t2i_adapter(t2i_data): def load_t2i_adapter(t2i_data):
@ -787,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None):
clip.load_from_state_dict(clip_data) clip.load_from_state_dict(clip_data)
return clip return clip
def load_gligen(ckpt_path):
data = utils.load_torch_file(ckpt_path)
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return model
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
with open(config_path, 'r') as stream: with open(config_path, 'r') as stream:
config = yaml.safe_load(stream) config = yaml.safe_load(stream)

2
folder_paths.py

@ -26,6 +26,8 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])

0
models/gligen/put_gligen_models_here

71
nodes.py

@ -490,6 +490,51 @@ class unCLIPConditioning:
c.append(n) c.append(n)
return (c, ) return (c, )
class GLIGENLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
RETURN_TYPES = ("GLIGEN",)
FUNCTION = "load_gligen"
CATEGORY = "_for_testing/gligen"
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,)
class GLIGENTextBoxApply:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ),
"clip": ("CLIP", ),
"gligen_textbox_model": ("GLIGEN", ),
"text": ("STRING", {"multiline": True}),
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "_for_testing/gligen"
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
c.append(n)
return (c, )
class EmptyLatentImage: class EmptyLatentImage:
def __init__(self, device="cpu"): def __init__(self, device="cpu"):
@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
negative_copy = [] negative_copy = []
control_nets = [] control_nets = []
def get_models(cond):
models = []
for c in cond:
if 'control' in c[1]:
models += [c[1]['control']]
if 'gligen' in c[1]:
models += [c[1]['gligen'][1]]
return models
for p in positive: for p in positive:
t = p[0] t = p[0]
if t.shape[0] < noise.shape[0]: if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0]) t = torch.cat([t] * noise.shape[0])
t = t.to(device) t = t.to(device)
if 'control' in p[1]:
control_nets += [p[1]['control']]
positive_copy += [[t] + p[1:]] positive_copy += [[t] + p[1:]]
for n in negative: for n in negative:
t = n[0] t = n[0]
if t.shape[0] < noise.shape[0]: if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0]) t = torch.cat([t] * noise.shape[0])
t = t.to(device) t = t.to(device)
if 'control' in n[1]:
control_nets += [n[1]['control']]
negative_copy += [[t] + n[1:]] negative_copy += [[t] + n[1:]]
control_net_models = [] models = get_models(positive) + get_models(negative)
for x in control_nets: comfy.model_management.load_controlnet_gpu(models)
control_net_models += x.get_control_models()
comfy.model_management.load_controlnet_gpu(control_net_models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS: if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu() samples = samples.cpu()
for c in control_nets: for m in models:
c.cleanup() m.cleanup()
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeTiled": VAEEncodeTiled, "VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel, "TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
} }

Loading…
Cancel
Save