|
|
@ -50,18 +50,22 @@ def convert_to_transformers(sd, prefix): |
|
|
|
if "{}proj".format(prefix) in sd_k: |
|
|
|
if "{}proj".format(prefix) in sd_k: |
|
|
|
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) |
|
|
|
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) |
|
|
|
|
|
|
|
|
|
|
|
sd = transformers_convert(sd, prefix, "vision_model.", 32) |
|
|
|
sd = transformers_convert(sd, prefix, "vision_model.", 48) |
|
|
|
return sd |
|
|
|
return sd |
|
|
|
|
|
|
|
|
|
|
|
def load_clipvision_from_sd(sd, prefix="", convert_keys=False): |
|
|
|
def load_clipvision_from_sd(sd, prefix="", convert_keys=False): |
|
|
|
if convert_keys: |
|
|
|
if convert_keys: |
|
|
|
sd = convert_to_transformers(sd, prefix) |
|
|
|
sd = convert_to_transformers(sd, prefix) |
|
|
|
if "vision_model.encoder.layers.30.layer_norm1.weight" in sd: |
|
|
|
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd: |
|
|
|
|
|
|
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") |
|
|
|
|
|
|
|
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: |
|
|
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") |
|
|
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") |
|
|
|
else: |
|
|
|
else: |
|
|
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") |
|
|
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") |
|
|
|
clip = ClipVisionModel(json_config) |
|
|
|
clip = ClipVisionModel(json_config) |
|
|
|
m, u = clip.load_sd(sd) |
|
|
|
m, u = clip.load_sd(sd) |
|
|
|
|
|
|
|
if len(m) > 0: |
|
|
|
|
|
|
|
print("missing clip vision:", m) |
|
|
|
u = set(u) |
|
|
|
u = set(u) |
|
|
|
keys = list(sd.keys()) |
|
|
|
keys = list(sd.keys()) |
|
|
|
for k in keys: |
|
|
|
for k in keys: |
|
|
@ -72,4 +76,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): |
|
|
|
|
|
|
|
|
|
|
|
def load(ckpt_path): |
|
|
|
def load(ckpt_path): |
|
|
|
sd = load_torch_file(ckpt_path) |
|
|
|
sd = load_torch_file(ckpt_path) |
|
|
|
return load_clipvision_from_sd(sd) |
|
|
|
if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd: |
|
|
|
|
|
|
|
return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
return load_clipvision_from_sd(sd) |
|
|
|