You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.5 KiB
78 lines
2.5 KiB
import torch |
|
from . import model_base |
|
from . import utils |
|
|
|
|
|
def state_dict_key_replace(state_dict, keys_to_replace):
    """Rename individual keys of ``state_dict`` in place and return it.

    ``keys_to_replace`` maps old key -> new key. Pairs are applied one at a time
    in iteration order; an old key that is absent from ``state_dict`` is skipped.
    A renamed entry is removed and re-inserted, so it moves to the end of the
    dict's key order.

    Args:
        state_dict: mapping to modify in place.
        keys_to_replace: mapping of old key -> new key.

    Returns:
        The same ``state_dict`` object.
    """
    for old_key, new_key in keys_to_replace.items():
        if old_key not in state_dict:
            continue
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
|
|
|
def state_dict_prefix_replace(state_dict, replace_prefix):
    """Replace key prefixes of ``state_dict`` in place and return it.

    For every ``old_prefix -> new_prefix`` pair in ``replace_prefix`` (processed
    in iteration order, each pass seeing the result of the previous one), every
    key starting with ``old_prefix`` is renamed to
    ``new_prefix + key[len(old_prefix):]``. An empty ``old_prefix`` matches
    every key, which adds ``new_prefix`` in front of all keys.

    Within one pass, all matching entries are removed before any renamed entry
    is inserted. A renamed key that coincides with another key still awaiting
    renaming therefore cannot overwrite that key's value.

    Args:
        state_dict: mapping to modify in place.
        replace_prefix: mapping of old prefix -> new prefix.

    Returns:
        The same ``state_dict`` object.
    """
    for old_prefix, new_prefix in replace_prefix.items():
        # Snapshot the matching keys before mutating the dict.
        matching = [key for key in state_dict if key.startswith(old_prefix)]
        # Pop everything first: inserting while popping could clobber a key that
        # is itself scheduled for renaming (e.g. "x" -> "a.x" while "a.x" exists).
        moved = [
            ("{}{}".format(new_prefix, key[len(old_prefix):]), state_dict.pop(key))
            for key in matching
        ]
        for new_key, value in moved:
            state_dict[new_key] = value
    return state_dict
|
|
|
|
|
class ClipTarget:
    """Bundle a tokenizer and a CLIP text encoder together with a params dict.

    Both constructor arguments are stored as given; nothing is validated or
    instantiated here. ``params`` starts as a fresh, empty dict for every
    instance.
    """

    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip
        self.params = dict()
|
|
|
class BASE:
    """Base description of a supported model family.

    Subclasses override the class attributes below. ``matches`` checks a
    detected UNet config against ``unet_config``, and an instance builds the
    model via ``get_model``. Subclasses must also define ``latent_format`` as a
    class attribute holding a zero-argument callable. ``__init__`` replaces it
    with the result of calling it, and BASE itself does not define it.
    """

    # Key/value pairs a detected UNet config must contain for ``matches``.
    unet_config = {}
    # Entries merged into (and overriding) the instance's unet_config in __init__.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Not read by this class; NOTE(review): presumably consumed by CLIP /
    # CLIP-vision loading code elsewhere -- confirm.
    clip_prefix = []
    clip_vision_prefix = None
    # When not None, get_model builds a model_base.SD21UNCLIP with this config.
    noise_aug_config = None

    @classmethod
    def matches(cls, unet_config):
        """Return True if ``unet_config`` agrees with every entry of ``cls.unet_config``.

        A key required by the class but missing from ``unet_config`` counts as
        a mismatch (returns False) instead of raising ``KeyError``.
        """
        for key, expected in cls.unet_config.items():
            if key not in unet_config or expected != unet_config[key]:
                return False
        return True

    def v_prediction(self, state_dict, prefix=""):
        """Return whether the model uses v-prediction; always False in the base class.

        ``state_dict`` and ``prefix`` are unused here (presumably for overrides).
        """
        return False

    def inpaint_model(self):
        """Return True when the config's "in_channels" is greater than 4.

        NOTE(review): presumably the extra input channels carry inpainting
        conditioning -- confirm. Raises KeyError if "in_channels" is missing.
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        """Store a copy of ``unet_config`` with ``unet_extra_config`` merged over it.

        The caller's dict is copied so the merge does not mutate it.
        """
        self.unet_config = unet_config.copy()
        # Swap the class-level callable for an instance of it.
        self.latent_format = self.latent_format()
        self.unet_config.update(self.unet_extra_config)

    def get_model(self, state_dict, prefix=""):
        """Build the model_base model matching this config.

        Returns ``model_base.SDInpaint`` when ``inpaint_model()`` is true,
        otherwise ``model_base.SD21UNCLIP`` when ``noise_aug_config`` is set,
        otherwise ``model_base.BaseModel``. Each receives this config object and
        ``v_prediction(state_dict, prefix)``.
        """
        is_inpaint = self.inpaint_model()
        v_pred = self.v_prediction(state_dict, prefix)
        if is_inpaint:
            return model_base.SDInpaint(self, v_prediction=v_pred)
        if self.noise_aug_config is not None:
            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=v_pred)
        return model_base.BaseModel(self, v_prediction=v_pred)

    def process_clip_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged; a hook for subclasses to override."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Prefix every key with "cond_stage_model." (modifies ``state_dict`` in place)."""
        replace_prefix = {"": "cond_stage_model."}
        return state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Prefix every key with "model.diffusion_model." (modifies ``state_dict`` in place)."""
        replace_prefix = {"": "model.diffusion_model."}
        return state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Prefix every key with "first_stage_model." (modifies ``state_dict`` in place)."""
        replace_prefix = {"": "first_stage_model."}
        return state_dict_prefix_replace(state_dict, replace_prefix)
|
|
|
|