|
|
|
@ -319,6 +319,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
|
|
|
|
"shift": 2.0, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
vae_key_prefix = ["vae."] |
|
|
|
|
text_encoder_key_prefix = ["text_encoder."] |
|
|
|
|
clip_vision_prefix = "clip_l_vision." |
|
|
|
|
|
|
|
|
|
def process_unet_state_dict(self, state_dict): |
|
|
|
|
key_list = list(state_dict.keys()) |
|
|
|
|
for y in ["weight", "bias"]: |
|
|
|
@ -355,6 +359,8 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|
|
|
|
"shift": 1.0, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
clip_vision_prefix = None |
|
|
|
|
|
|
|
|
|
def get_model(self, state_dict, prefix="", device=None): |
|
|
|
|
out = model_base.StableCascade_B(self, device=device) |
|
|
|
|
return out |
|
|
|
|