|
|
|
@ -46,6 +46,18 @@ def detect_unet_config(state_dict, key_prefix):
|
|
|
|
|
unet_config['c_cond'] = 2048 |
|
|
|
|
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: |
|
|
|
|
unet_config['stable_cascade_stage'] = 'b' |
|
|
|
|
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] |
|
|
|
|
if w.shape[-1] == 640: |
|
|
|
|
unet_config['c_hidden'] = [320, 640, 1280, 1280] |
|
|
|
|
unet_config['nhead'] = [-1, -1, 20, 20] |
|
|
|
|
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] |
|
|
|
|
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] |
|
|
|
|
elif w.shape[-1] == 576: #stage b lite |
|
|
|
|
unet_config['c_hidden'] = [320, 576, 1152, 1152] |
|
|
|
|
unet_config['nhead'] = [-1, 9, 18, 18] |
|
|
|
|
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] |
|
|
|
|
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] |
|
|
|
|
|
|
|
|
|
return unet_config |
|
|
|
|
|
|
|
|
|
unet_config = { |
|
|
|
|