|
|
|
@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
self.adm_channels = unet_config.get("adm_in_channels", None) |
|
|
|
|
if self.adm_channels is None: |
|
|
|
|
self.adm_channels = 0 |
|
|
|
|
self.inpaint_model = False |
|
|
|
|
|
|
|
|
|
self.concat_keys = () |
|
|
|
|
logging.info("model_type {}".format(model_type.name)) |
|
|
|
|
logging.debug("adm {}".format(self.adm_channels)) |
|
|
|
|
|
|
|
|
@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
def extra_conds(self, **kwargs): |
|
|
|
|
out = {} |
|
|
|
|
if self.inpaint_model: |
|
|
|
|
concat_keys = ("mask", "masked_image") |
|
|
|
|
if len(self.concat_keys) > 0: |
|
|
|
|
cond_concat = [] |
|
|
|
|
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) |
|
|
|
|
concat_latent_image = kwargs.get("concat_latent_image", None) |
|
|
|
@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) |
|
|
|
|
|
|
|
|
|
if len(denoise_mask.shape) == len(noise.shape): |
|
|
|
|
denoise_mask = denoise_mask[:,:1] |
|
|
|
|
|
|
|
|
|
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) |
|
|
|
|
if denoise_mask.shape[-2:] != noise.shape[-2:]: |
|
|
|
|
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") |
|
|
|
|
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
if len(denoise_mask.shape) == len(noise.shape): |
|
|
|
|
denoise_mask = denoise_mask[:,:1] |
|
|
|
|
|
|
|
|
|
def blank_inpaint_image_like(latent_image): |
|
|
|
|
blank_image = torch.ones_like(latent_image) |
|
|
|
|
# these are the values for "zero" in pixel space translated to latent space |
|
|
|
|
blank_image[:,0] *= 0.8223 |
|
|
|
|
blank_image[:,1] *= -0.6876 |
|
|
|
|
blank_image[:,2] *= 0.6364 |
|
|
|
|
blank_image[:,3] *= 0.1380 |
|
|
|
|
return blank_image |
|
|
|
|
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) |
|
|
|
|
if denoise_mask.shape[-2:] != noise.shape[-2:]: |
|
|
|
|
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") |
|
|
|
|
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) |
|
|
|
|
|
|
|
|
|
for ck in concat_keys: |
|
|
|
|
for ck in self.concat_keys: |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
if ck == "mask": |
|
|
|
|
cond_concat.append(denoise_mask.to(device)) |
|
|
|
@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
if ck == "mask": |
|
|
|
|
cond_concat.append(torch.ones_like(noise)[:,:1]) |
|
|
|
|
elif ck == "masked_image": |
|
|
|
|
cond_concat.append(blank_inpaint_image_like(noise)) |
|
|
|
|
cond_concat.append(self.blank_inpaint_image_like(noise)) |
|
|
|
|
data = torch.cat(cond_concat, dim=1) |
|
|
|
|
out['c_concat'] = comfy.conds.CONDNoiseShape(data) |
|
|
|
|
|
|
|
|
@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
return unet_state_dict |
|
|
|
|
|
|
|
|
|
def set_inpaint(self): |
|
|
|
|
self.inpaint_model = True |
|
|
|
|
self.concat_keys = ("mask", "masked_image") |
|
|
|
|
def blank_inpaint_image_like(latent_image): |
|
|
|
|
blank_image = torch.ones_like(latent_image) |
|
|
|
|
# these are the values for "zero" in pixel space translated to latent space |
|
|
|
|
blank_image[:,0] *= 0.8223 |
|
|
|
|
blank_image[:,1] *= -0.6876 |
|
|
|
|
blank_image[:,2] *= 0.6364 |
|
|
|
|
blank_image[:,3] *= 0.1380 |
|
|
|
|
return blank_image |
|
|
|
|
self.blank_inpaint_image_like = blank_inpaint_image_like |
|
|
|
|
|
|
|
|
|
def memory_required(self, input_shape): |
|
|
|
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): |
|
|
|
|