diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index cb0a7983..ea936e06 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -498,7 +498,7 @@ class UNetModel(nn.Module): if self.num_classes is not None: if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 768a47f9..f5ac7c2f 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -41,8 +41,12 @@ class AbstractLowScaleModel(nn.Module): self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) + def q_sample(self, x_start, t, noise=None, seed=None): + if noise is None: + if seed is None: + noise = torch.randn_like(x_start) + else: + noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device) return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise) @@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): super().__init__(noise_schedule_config=noise_schedule_config) self.max_noise_level = max_noise_level - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) return z, noise_level diff --git a/comfy/model_base.py b/comfy/model_base.py index 64a380ff..f5952620 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,7 +1,7 @@ import torch -from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep +from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management import comfy.conds import comfy.ops @@ -78,8 +78,9 @@ class BaseModel(torch.nn.Module): extra_conds = {} for o in kwargs: extra = kwargs[o] - if hasattr(extra, "to"): - extra = extra.to(dtype) + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) extra_conds[o] = extra model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() @@ -368,20 +369,31 @@ class Stable_Zero123(BaseModel): class SD_X4Upscaler(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): super().__init__(model_config, model_type, device=device) + self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350) def extra_conds(self, **kwargs): out = {} image = kwargs.get("concat_image", None) noise = kwargs.get("noise", None) + noise_augment = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + seed = kwargs["seed"] - 10 + + noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment) if image is None: image = torch.zeros_like(noise)[:,:3] if image.shape[1:] != noise.shape[1:]: - image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + noise_level = torch.tensor([noise_level], device=device) + if noise_augment > 0: + image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed) image = utils.resize_to_batch_size(image, noise.shape[0]) out['c_concat'] = comfy.conds.CONDNoiseShape(image) + out['y'] = comfy.conds.CONDRegular(noise_level) return out diff --git a/comfy/samplers.py b/comfy/samplers.py index 0453c1f6..89d8d4f2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -603,8 +603,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model latent_image = model.process_latent_in(latent_image) if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) #make sure each cond area has an opposite one with the same area for c in positive: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e7a6cc17..1d442d4d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -290,6 +290,7 @@ class SD_X4Upscaler(SD20): unet_extra_config = { "disable_self_attentions": [True, True, True, False], + "num_classes": 1000, "num_heads": 8, "num_head_channels": -1, } diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index 38a027e0..28c1cb0f 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning: "positive": ("CONDITIONING",), "negative": ("CONDITIONING",), "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), - # "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -18,7 +18,7 @@ class SD_4XUpscale_Conditioning: CATEGORY = "conditioning/upscale_diffusion" - def encode(self, images, positive, negative, scale_ratio): + def encode(self, images, positive, negative, scale_ratio, noise_augmentation): width = max(1, round(images.shape[-2] * scale_ratio)) height = max(1, round(images.shape[-3] * scale_ratio)) @@ -30,11 +30,13 @@ class SD_4XUpscale_Conditioning: for t in positive: n = [t[0], t[1].copy()] n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation out_cp.append(n) for t in negative: n = [t[0], t[1].copy()] n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation out_cn.append(n) latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])