diff --git a/comfy/samplers.py b/comfy/samplers.py index 9eee25a9..5b01d48f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -17,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 + if 'timestep_start' in cond[1]: + timestep_start = cond[1]['timestep_start'] + if timestep_in > timestep_start: + return None + if 'timestep_end' in cond[1]: + timestep_end = cond[1]['timestep_end'] + if timestep_in < timestep_end: + return None if 'area' in cond[1]: area = cond[1]['area'] if 'strength' in cond[1]: @@ -428,6 +436,25 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] +def calculate_start_end_timesteps(model, conds): + for t in range(len(conds)): + x = conds[t] + + timestep_start = None + timestep_end = None + if 'start_percent' in x[1]: + timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0))) + if 'end_percent' in x[1]: + timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0))) + + if (timestep_start is not None) or (timestep_end is not None): + n = x[1].copy() + if (timestep_start is not None): + n['timestep_start'] = timestep_start + if (timestep_end is not None): + n['timestep_end'] = timestep_end + conds[t] = [x[0], n] + def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] cond_other = [] @@ -571,6 +598,9 @@ class KSampler: resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + calculate_start_end_timesteps(self.model_wrap, negative) + calculate_start_end_timesteps(self.model_wrap, positive) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) diff --git a/nodes.py b/nodes.py index a1c4b843..db39c0ce 100644 --- a/nodes.py +++ b/nodes.py @@ -204,6 +204,28 @@ class ConditioningZeroOut: c.append(n) return (c, ) +class ConditioningSetTimestepRange: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "set_range" + + CATEGORY = "advanced/conditioning" + + def set_range(self, conditioning, start, end): + c = [] + for t in conditioning: + d = t[1].copy() + d['start_percent'] = start + d['end_percent'] = end + n = [t[0], d] + c.append(n) + return (c, ) + class VAEDecode: @classmethod def INPUT_TYPES(s): @@ -1444,6 +1466,7 @@ NODE_CLASS_MAPPINGS = { "SaveLatent": SaveLatent, "ConditioningZeroOut": ConditioningZeroOut, + "ConditioningSetTimestepRange": ConditioningSetTimestepRange, } NODE_DISPLAY_NAME_MAPPINGS = {