|
|
|
@ -75,7 +75,7 @@ class SD20(supported_models_base.BASE):
|
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format |
|
|
|
|
replace_prefix["cond_stage_model.model."] = "clip_h." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24) |
|
|
|
|
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|
def process_clip_state_dict_for_saving(self, state_dict): |
|
|
|
@ -134,7 +134,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "clip_g." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) |
|
|
|
|
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") |
|
|
|
|
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
@ -182,10 +182,8 @@ class SDXL(supported_models_base.BASE):
|
|
|
|
|
replace_prefix["conditioner.embedders.1.model."] = "clip_g." |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) |
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) |
|
|
|
|
keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection" |
|
|
|
|
|
|
|
|
|
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) |
|
|
|
|
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|
def process_clip_state_dict_for_saving(self, state_dict): |
|
|
|
@ -338,6 +336,12 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
|
|
|
|
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|
def process_clip_state_dict(self, state_dict): |
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) |
|
|
|
|
if "clip_g.text_projection" in state_dict: |
|
|
|
|
state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1) |
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
|
def get_model(self, state_dict, prefix="", device=None): |
|
|
|
|
out = model_base.StableCascade_C(self, device=device) |
|
|
|
|
return out |
|
|
|
|