diff --git a/comfy/model_base.py b/comfy/model_base.py index 5da71e63..bc019de5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -380,6 +380,36 @@ class SVD_img2vid(BaseModel): out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out +class SV3D_u(SVD_img2vid): + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + +class SV3D_p(SVD_img2vid): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder_512 = Timestep(512) + + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here + azimuth = kwargs.get("azimuth", 0) + noise = kwargs.get("noise", None) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0)))) + + out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) + return torch.cat(out, dim=1) + + class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e12935a2..2ce9736b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -284,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None +class SV3D_u(SVD_img2vid): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 256, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + vae_key_prefix = ["conditioner.embedders.1.encoder."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_u(self, device=device) + return out + +class SV3D_p(SV3D_u): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 1280, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_p(self, device=device) + return out + class Stable_Zero123(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -405,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C): return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index 4375d8f9..be2e34c2 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -29,8 +29,8 @@ class StableZero123_Conditioning: "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -62,10 +62,10 @@ class StableZero123_Conditioning_Batched: "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -95,8 +95,49 @@ class StableZero123_Conditioning_Batched: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, } diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index a5262565..1a0189ed 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -79,6 +79,33 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): CATEGORY = "_for_testing" @@ -98,6 +125,7 @@ NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, }