From d46583ecece5014f23f9f47f7952c8aecd8cc491 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 27 Feb 2024 15:12:33 -0500 Subject: [PATCH] Playground V2.5 support with ModelSamplingContinuousEDM node. Use ModelSamplingContinuousEDM with edm_playground_v2.5 selected. --- comfy/latent_formats.py | 27 +++++++++++++++++++++++++++ comfy/model_sampling.py | 13 +++++++++---- comfy/samplers.py | 2 +- comfy_extras/nodes_model_advanced.py | 13 +++++++++++-- 4 files changed, 48 insertions(+), 7 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 03fd59e3..674364e7 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,3 +1,4 @@ +import torch class LatentFormat: scale_factor = 1.0 @@ -34,6 +35,32 @@ class SDXL(LatentFormat): ] self.taesd_decoder_name = "taesdxl_decoder" +class SDXL_Playground_2_5(LatentFormat): + def __init__(self): + self.scale_factor = 0.5 + self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1) + self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1) + + self.latent_rgb_factors = [ + # R G B + [ 0.3920, 0.4054, 0.4549], + [-0.2634, -0.0196, 0.0653], + [ 0.0568, 0.1687, -0.0755], + [-0.3112, -0.2359, -0.2076] + ] + self.taesd_decoder_name = "taesdxl_decoder" + + def process_in(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return (latent - latents_mean) * self.scale_factor / latents_std + + def process_out(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return latent * latents_std / self.scale_factor + latents_mean + + class SD_X4(LatentFormat): def __init__(self): self.scale_factor = 0.08333 diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 97e91a01..e7f8bc6a 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -17,6 +17,11 @@ class V_PREDICTION(EPS): sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class EDM(V_PREDICTION): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None): @@ -92,8 +97,6 @@ class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingContinuousEDM(torch.nn.Module): def __init__(self, model_config=None): super().__init__() - self.sigma_data = 1.0 - if model_config is not None: sampling_settings = model_config.sampling_settings else: @@ -101,9 +104,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module): sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_max = sampling_settings.get("sigma_max", 120.0) - self.set_sigma_range(sigma_min, sigma_max) + sigma_data = sampling_settings.get("sigma_data", 1.0) + self.set_parameters(sigma_min, sigma_max, sigma_data) - def set_sigma_range(self, sigma_min, sigma_max): + def set_parameters(self, sigma_min, sigma_max, sigma_data): + self.sigma_data = sigma_data sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers diff --git a/comfy/samplers.py b/comfy/samplers.py index 491c95d3..e5569322 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -588,7 +588,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, positive) - if latent_image is not None: + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = model.process_latent_in(latent_image) if hasattr(model, 'extra_conds'): diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 1b3f3945..21af4b73 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -1,6 +1,7 @@ import folder_paths import comfy.sd import comfy.model_sampling +import comfy.latent_formats import torch class LCM(comfy.model_sampling.EPS): @@ -135,7 +136,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "eps"],), + "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -148,17 +149,25 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + latent_format = None + sigma_data = 1.0 if sampling == "eps": sampling_type = comfy.model_sampling.EPS elif sampling == "v_prediction": sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "edm_playground_v2.5": + sampling_type = comfy.model_sampling.EDM + sigma_data = 0.5 + latent_format = comfy.latent_formats.SDXL_Playground_2_5() class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config) - model_sampling.set_sigma_range(sigma_min, sigma_max) + model_sampling.set_parameters(sigma_min, sigma_max, sigma_data) m.add_object_patch("model_sampling", model_sampling) + if latent_format is not None: + m.add_object_patch("latent_format", latent_format) return (m, ) class RescaleCFG: