|
|
@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module): |
|
|
|
self.adm_channels = unet_config.get("adm_in_channels", None) |
|
|
|
self.adm_channels = unet_config.get("adm_in_channels", None) |
|
|
|
if self.adm_channels is None: |
|
|
|
if self.adm_channels is None: |
|
|
|
self.adm_channels = 0 |
|
|
|
self.adm_channels = 0 |
|
|
|
|
|
|
|
self.inpaint_model = False |
|
|
|
print("model_type", model_type.name) |
|
|
|
print("model_type", model_type.name) |
|
|
|
print("adm", self.adm_channels) |
|
|
|
print("adm", self.adm_channels) |
|
|
|
|
|
|
|
|
|
|
@ -71,6 +72,37 @@ class BaseModel(torch.nn.Module): |
|
|
|
def encode_adm(self, **kwargs): |
|
|
|
def encode_adm(self, **kwargs): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cond_concat(self, **kwargs): |
|
|
|
|
|
|
|
if self.inpaint_model: |
|
|
|
|
|
|
|
concat_keys = ("mask", "masked_image") |
|
|
|
|
|
|
|
cond_concat = [] |
|
|
|
|
|
|
|
denoise_mask = kwargs.get("denoise_mask", None) |
|
|
|
|
|
|
|
latent_image = kwargs.get("latent_image", None) |
|
|
|
|
|
|
|
noise = kwargs.get("noise", None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def blank_inpaint_image_like(latent_image): |
|
|
|
|
|
|
|
blank_image = torch.ones_like(latent_image) |
|
|
|
|
|
|
|
# these are the values for "zero" in pixel space translated to latent space |
|
|
|
|
|
|
|
blank_image[:,0] *= 0.8223 |
|
|
|
|
|
|
|
blank_image[:,1] *= -0.6876 |
|
|
|
|
|
|
|
blank_image[:,2] *= 0.6364 |
|
|
|
|
|
|
|
blank_image[:,3] *= 0.1380 |
|
|
|
|
|
|
|
return blank_image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for ck in concat_keys: |
|
|
|
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
|
|
|
if ck == "mask": |
|
|
|
|
|
|
|
cond_concat.append(denoise_mask[:,:1]) |
|
|
|
|
|
|
|
elif ck == "masked_image": |
|
|
|
|
|
|
|
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
if ck == "mask": |
|
|
|
|
|
|
|
cond_concat.append(torch.ones_like(noise)[:,:1]) |
|
|
|
|
|
|
|
elif ck == "masked_image": |
|
|
|
|
|
|
|
cond_concat.append(blank_inpaint_image_like(noise)) |
|
|
|
|
|
|
|
return cond_concat |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def load_model_weights(self, sd, unet_prefix=""): |
|
|
|
def load_model_weights(self, sd, unet_prefix=""): |
|
|
|
to_load = {} |
|
|
|
to_load = {} |
|
|
|
keys = list(sd.keys()) |
|
|
|
keys = list(sd.keys()) |
|
|
@ -112,7 +144,7 @@ class BaseModel(torch.nn.Module): |
|
|
|
return {**unet_state_dict, **vae_state_dict, **clip_state_dict} |
|
|
|
return {**unet_state_dict, **vae_state_dict, **clip_state_dict} |
|
|
|
|
|
|
|
|
|
|
|
def set_inpaint(self): |
|
|
|
def set_inpaint(self): |
|
|
|
self.concat_keys = ("mask", "masked_image") |
|
|
|
self.inpaint_model = True |
|
|
|
|
|
|
|
|
|
|
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): |
|
|
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): |
|
|
|
adm_inputs = [] |
|
|
|
adm_inputs = [] |
|
|
|