|
|
@ -1,7 +1,7 @@ |
|
|
|
import torch |
|
|
|
import torch |
|
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel |
|
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep |
|
|
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation |
|
|
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation |
|
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep |
|
|
|
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation |
|
|
|
import comfy.model_management |
|
|
|
import comfy.model_management |
|
|
|
import comfy.conds |
|
|
|
import comfy.conds |
|
|
|
import comfy.ops |
|
|
|
import comfy.ops |
|
|
@ -78,8 +78,9 @@ class BaseModel(torch.nn.Module): |
|
|
|
extra_conds = {} |
|
|
|
extra_conds = {} |
|
|
|
for o in kwargs: |
|
|
|
for o in kwargs: |
|
|
|
extra = kwargs[o] |
|
|
|
extra = kwargs[o] |
|
|
|
if hasattr(extra, "to"): |
|
|
|
if hasattr(extra, "dtype"): |
|
|
|
extra = extra.to(dtype) |
|
|
|
if extra.dtype != torch.int and extra.dtype != torch.long: |
|
|
|
|
|
|
|
extra = extra.to(dtype) |
|
|
|
extra_conds[o] = extra |
|
|
|
extra_conds[o] = extra |
|
|
|
|
|
|
|
|
|
|
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() |
|
|
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() |
|
|
class SD_X4Upscaler(BaseModel):
    """Stable Diffusion x4 upscaler model wrapper.

    Adds the image-concat noise augmentor used by the x4 upscaling
    checkpoint: the low-resolution conditioning image is (optionally)
    noise-augmented and provided as 'c_concat', and the chosen noise
    level is provided as the 'y' conditioning.
    """
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        # Augmentor that adds scheduled noise to the conditioning image;
        # schedule parameters and max_noise_level=350 are hard-coded for
        # this model.
        self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)

    def extra_conds(self, **kwargs):
        """Build the extra conditioning dict for sampling.

        Reads from kwargs:
            concat_image: conditioning image to upscale (may be None).
            noise: noise tensor; its shape and batch size are used to
                match the conditioning image. Presumably NCHW — verify
                against caller.
            noise_augmentation: fraction in [0, 1] of the augmentor's
                max_noise_level to apply (default 0.0).
            device: target device for created/moved tensors.
            seed: sampling seed; offset by -10 below.

        Returns:
            dict with 'c_concat' (CONDNoiseShape of the image) and
            'y' (CONDRegular of the noise level tensor).
        """
        out = {}

        image = kwargs.get("concat_image", None)
        noise = kwargs.get("noise", None)
        noise_augment = kwargs.get("noise_augmentation", 0.0)
        device = kwargs["device"]
        # Offset the seed so the augmentor's noise stream differs from the
        # sampler's noise stream for the same user-supplied seed.
        seed = kwargs["seed"] - 10

        # Map the [0, 1] augmentation fraction to an integer noise level.
        noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)

        if image is None:
            # No conditioning image supplied: substitute zeros shaped like
            # the noise tensor, keeping only its first 3 channels.
            image = torch.zeros_like(noise)[:,:3]

        if image.shape[1:] != noise.shape[1:]:
            # Resize the conditioning image (moved to the target device) to
            # the noise tensor's spatial size.
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        noise_level = torch.tensor([noise_level], device=device)

        if noise_augment > 0:
            # The augmentor returns both the noised image and the (possibly
            # adjusted) noise level actually used.
            image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)

        # Match the conditioning image's batch size to the noise batch.
        image = utils.resize_to_batch_size(image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(image)
        out['y'] = comfy.conds.CONDRegular(noise_level)
        return out