|
|
|
@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Adapter(nn.Module): |
|
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True): |
|
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True): |
|
|
|
|
super(Adapter, self).__init__() |
|
|
|
|
self.unshuffle = nn.PixelUnshuffle(8) |
|
|
|
|
unshuffle = 8 |
|
|
|
|
resblock_no_downsample = [] |
|
|
|
|
resblock_downsample = [3, 2, 1] |
|
|
|
|
self.xl = xl |
|
|
|
|
if self.xl: |
|
|
|
|
unshuffle = 16 |
|
|
|
|
resblock_no_downsample = [1] |
|
|
|
|
resblock_downsample = [2] |
|
|
|
|
|
|
|
|
|
self.input_channels = cin // (unshuffle * unshuffle) |
|
|
|
|
self.unshuffle = nn.PixelUnshuffle(unshuffle) |
|
|
|
|
self.channels = channels |
|
|
|
|
self.nums_rb = nums_rb |
|
|
|
|
self.body = [] |
|
|
|
|
for i in range(len(channels)): |
|
|
|
|
for j in range(nums_rb): |
|
|
|
|
if (i != 0) and (j == 0): |
|
|
|
|
if (i in resblock_downsample) and (j == 0): |
|
|
|
|
self.body.append( |
|
|
|
|
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) |
|
|
|
|
elif (i in resblock_no_downsample) and (j == 0): |
|
|
|
|
self.body.append( |
|
|
|
|
ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) |
|
|
|
|
else: |
|
|
|
|
self.body.append( |
|
|
|
|
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) |
|
|
|
@ -128,8 +141,16 @@ class Adapter(nn.Module):
|
|
|
|
|
for j in range(self.nums_rb): |
|
|
|
|
idx = i * self.nums_rb + j |
|
|
|
|
x = self.body[idx](x) |
|
|
|
|
features.append(None) |
|
|
|
|
features.append(None) |
|
|
|
|
if self.xl: |
|
|
|
|
features.append(None) |
|
|
|
|
if i == 0: |
|
|
|
|
features.append(None) |
|
|
|
|
features.append(None) |
|
|
|
|
if i == 2: |
|
|
|
|
features.append(None) |
|
|
|
|
else: |
|
|
|
|
features.append(None) |
|
|
|
|
features.append(None) |
|
|
|
|
features.append(x) |
|
|
|
|
|
|
|
|
|
return features |
|
|
|
@ -243,10 +264,14 @@ class extractor(nn.Module):
|
|
|
|
|
class Adapter_light(nn.Module): |
|
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64): |
|
|
|
|
super(Adapter_light, self).__init__() |
|
|
|
|
self.unshuffle = nn.PixelUnshuffle(8) |
|
|
|
|
unshuffle = 8 |
|
|
|
|
self.unshuffle = nn.PixelUnshuffle(unshuffle) |
|
|
|
|
self.input_channels = cin // (unshuffle * unshuffle) |
|
|
|
|
self.channels = channels |
|
|
|
|
self.nums_rb = nums_rb |
|
|
|
|
self.body = [] |
|
|
|
|
self.xl = False |
|
|
|
|
|
|
|
|
|
for i in range(len(channels)): |
|
|
|
|
if i == 0: |
|
|
|
|
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False)) |
|
|
|
|