|
|
|
@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
|
|
|
|
|
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() |
|
|
|
|
|
|
|
|
|
replace_prefix = {} |
|
|
|
|
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) |
|
|
|
|
replace_prefix["cond_stage_model."] = "clip_l." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|
def process_clip_state_dict_for_saving(self, state_dict): |
|
|
|
@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
|
|
|
|
|
|
|
|
|
|
def process_clip_state_dict(self, state_dict): |
|
|
|
|
replace_prefix = {} |
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) |
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) |
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format |
|
|
|
|
replace_prefix["cond_stage_model.model."] = "clip_h." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24) |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|
def process_clip_state_dict_for_saving(self, state_dict): |
|
|
|
@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
|
|
|
|
|
def process_clip_state_dict(self, state_dict): |
|
|
|
|
keys_to_replace = {} |
|
|
|
|
replace_prefix = {} |
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "clip_g." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) |
|
|
|
|
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" |
|
|
|
|
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" |
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) |
|
|
|
|
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
|
|
|
|
|
keys_to_replace = {} |
|
|
|
|
replace_prefix = {} |
|
|
|
|
|
|
|
|
|
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" |
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) |
|
|
|
|
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" |
|
|
|
|
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" |
|
|
|
|
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" |
|
|
|
|
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model" |
|
|
|
|
replace_prefix["conditioner.embedders.1.model."] = "clip_g." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) |
|
|
|
|
keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection" |
|
|
|
|
|
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) |
|
|
|
|
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|