diff --git a/comfy/model_base.py b/comfy/model_base.py index bc019de5..6f530d2f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module): self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 - self.inpaint_model = False + + self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) @@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module): def extra_conds(self, **kwargs): out = {} - if self.inpaint_model: - concat_keys = ("mask", "masked_image") + if len(self.concat_keys) > 0: cond_concat = [] denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) concat_latent_image = kwargs.get("concat_latent_image", None) @@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module): concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) - if len(denoise_mask.shape) == len(noise.shape): - denoise_mask = denoise_mask[:,:1] - - denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) - if denoise_mask.shape[-2:] != noise.shape[-2:]: - denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") - denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + if denoise_mask is not None: + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] - def blank_inpaint_image_like(latent_image): - blank_image = torch.ones_like(latent_image) - # these are the values for "zero" in pixel space translated to latent space - blank_image[:,0] *= 0.8223 - blank_image[:,1] *= -0.6876 - blank_image[:,2] *= 0.6364 - blank_image[:,3] *= 0.1380 - return blank_image + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) - for ck in concat_keys: + for ck in self.concat_keys: if denoise_mask is not None: if ck == "mask": cond_concat.append(denoise_mask.to(device)) @@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module): if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) + cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = comfy.conds.CONDNoiseShape(data) @@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module): return unet_state_dict def set_inpaint(self): - self.inpaint_model = True + self.concat_keys = ("mask", "masked_image") + def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + self.blank_inpaint_image_like = blank_inpaint_image_like def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():