Browse Source

Implement Perp-Neg

pull/2303/head
Hari 11 months ago
parent
commit
574363a8a6
  1. 3
      comfy/samplers.py
  2. 58
      comfy_extras/nodes_perpneg.py
  3. 1
      nodes.py

3
comfy/samplers.py

@ -251,7 +251,8 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

58
comfy_extras/nodes_perpneg.py

@ -0,0 +1,58 @@
import torch
import comfy.model_management
import comfy.sample
import comfy.samplers
import comfy.utils
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"clip": ("CLIP", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, clip, neg_scale):
m = model.clone()
tokens = clip.tokenize("")
nocond, nocond_pooled = clip.encode_from_tokens(tokens, return_pooled=True)
nocond = [[nocond, {"pooled_output": nocond_pooled}]]
nocond = comfy.sample.convert_cond(nocond)
def cfg_function(args):
model = args["model"]
noise_pred_pos = args["cond_denoised"]
noise_pred_neg = args["uncond_denoised"]
cond_scale = args["cond_scale"]
x = args["input"]
sigma = args["sigma"]
model_options = args["model_options"]
(noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond, None, x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
cfg_result = x - cfg_result
return cfg_result
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg",
}

1
nodes.py

@ -1868,6 +1868,7 @@ def init_custom_nodes():
"nodes_images.py",
"nodes_video_model.py",
"nodes_sag.py",
"nodes_perpneg.py",
]
for node_file in extras_files:

Loading…
Cancel
Save