Browse Source

sampler_cfg_function now uses a dict for the argument.

This means arguments can be added without issues.
pull/770/head
comfyanonymous 1 year ago
parent
commit
388567f20b
  1. 3
      comfy/samplers.py
  2. 5
      comfy/sd.py

3
comfy/samplers.py

@ -273,7 +273,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area() max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale) args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
return model_options["sampler_cfg_function"](args)
else: else:
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale

5
comfy/sd.py

@ -1,6 +1,7 @@
import torch import torch
import contextlib import contextlib
import copy import copy
import inspect
from . import sd1_clip from . import sd1_clip
from . import sd2_clip from . import sd2_clip
@ -313,9 +314,11 @@ class ModelPatcher:
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def set_model_sampler_cfg_function(self, sampler_cfg_function): def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_patch(self, patch, name): def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches" not in to: if "patches" not in to:

Loading…
Cancel
Save