|
|
|
@ -48,7 +48,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|
|
|
|
unet_config["dtype"] = dtype |
|
|
|
|
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] |
|
|
|
|
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] |
|
|
|
|
out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0] |
|
|
|
|
|
|
|
|
|
out_key = '{}out.2.weight'.format(key_prefix) |
|
|
|
|
if out_key in state_dict: |
|
|
|
|
out_channels = state_dict[out_key].shape[0] |
|
|
|
|
else: |
|
|
|
|
out_channels = 4 |
|
|
|
|
|
|
|
|
|
num_res_blocks = [] |
|
|
|
|
channel_mult = [] |
|
|
|
|