Browse Source

Add missing samplers to KSamplerSelect.

pull/1622/head
comfyanonymous 1 year ago
parent
commit
d234ca558a
  1. 20
      comfy/samplers.py
  2. 4
      comfy_extras/nodes_custom_sampler.py

20
comfy/samplers.py

@ -711,6 +711,17 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
print("error invalid scheduler", self.scheduler)
return sigmas
def sampler_class(name):
if name == "uni_pc":
sampler = UNIPC
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
elif name == "ddim":
sampler = DDIM
else:
sampler = ksampler(name)
return sampler
class KSampler:
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
@ -769,13 +780,6 @@ class KSampler:
else:
return torch.zeros_like(noise)
if self.sampler == "uni_pc":
sampler = UNIPC
elif self.sampler == "uni_pc_bh2":
sampler = UNIPCBH2
elif self.sampler == "ddim":
sampler = DDIM
else:
sampler = ksampler(self.sampler)
sampler = sampler_class(self.sampler)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)

4
comfy_extras/nodes_custom_sampler.py

@ -28,7 +28,7 @@ class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sampler_name": (comfy.samplers.KSAMPLER_NAMES, ),
{"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
}
}
RETURN_TYPES = ("SAMPLER",)
@ -37,7 +37,7 @@ class KSamplerSelect:
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.ksampler(sampler_name)()
sampler = comfy.samplers.sampler_class(sampler_name)()
return (sampler, )
class SamplerCustom:

Loading…
Cancel
Save