Browse Source

Playground V2.5 support with ModelSamplingContinuousEDM node.

Use ModelSamplingContinuousEDM with edm_playground_v2.5 selected.
pull/2918/head
comfyanonymous 9 months ago
parent
commit
d46583ecec
  1. 27
      comfy/latent_formats.py
  2. 13
      comfy/model_sampling.py
  3. 2
      comfy/samplers.py
  4. 13
      comfy_extras/nodes_model_advanced.py

27
comfy/latent_formats.py

@ -1,3 +1,4 @@
import torch
class LatentFormat: class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
@ -34,6 +35,32 @@ class SDXL(LatentFormat):
] ]
self.taesd_decoder_name = "taesdxl_decoder" self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
def __init__(self):
self.scale_factor = 0.5
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class SD_X4(LatentFormat): class SD_X4(LatentFormat):
def __init__(self): def __init__(self):
self.scale_factor = 0.08333 self.scale_factor = 0.08333

13
comfy/model_sampling.py

@ -17,6 +17,11 @@ class V_PREDICTION(EPS):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
@ -92,8 +97,6 @@ class ModelSamplingDiscrete(torch.nn.Module):
class ModelSamplingContinuousEDM(torch.nn.Module): class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
super().__init__() super().__init__()
self.sigma_data = 1.0
if model_config is not None: if model_config is not None:
sampling_settings = model_config.sampling_settings sampling_settings = model_config.sampling_settings
else: else:
@ -101,9 +104,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0) sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max) sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_parameters(sigma_min, sigma_max, sigma_data)
def set_sigma_range(self, sigma_min, sigma_max): def set_parameters(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers

2
comfy/samplers.py

@ -588,7 +588,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive) calculate_start_end_timesteps(model, positive)
if latent_image is not None: if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image) latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'): if hasattr(model, 'extra_conds'):

13
comfy_extras/nodes_model_advanced.py

@ -1,6 +1,7 @@
import folder_paths import folder_paths
import comfy.sd import comfy.sd
import comfy.model_sampling import comfy.model_sampling
import comfy.latent_formats
import torch import torch
class LCM(comfy.model_sampling.EPS): class LCM(comfy.model_sampling.EPS):
@ -135,7 +136,7 @@ class ModelSamplingContinuousEDM:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],), "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}} }}
@ -148,17 +149,25 @@ class ModelSamplingContinuousEDM:
def patch(self, model, sampling, sigma_max, sigma_min): def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone() m = model.clone()
latent_format = None
sigma_data = 1.0
if sampling == "eps": if sampling == "eps":
sampling_type = comfy.model_sampling.EPS sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction": elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "edm_playground_v2.5":
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max) model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
if latent_format is not None:
m.add_object_patch("latent_format", latent_format)
return (m, ) return (m, )
class RescaleCFG: class RescaleCFG:

Loading…
Cancel
Save