|
|
@ -103,17 +103,17 @@ class ResnetBlock(nn.Module): |
|
|
|
class Adapter(nn.Module): |
|
|
|
class Adapter(nn.Module): |
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True): |
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True): |
|
|
|
super(Adapter, self).__init__() |
|
|
|
super(Adapter, self).__init__() |
|
|
|
unshuffle = 8 |
|
|
|
self.unshuffle_amount = 8 |
|
|
|
resblock_no_downsample = [] |
|
|
|
resblock_no_downsample = [] |
|
|
|
resblock_downsample = [3, 2, 1] |
|
|
|
resblock_downsample = [3, 2, 1] |
|
|
|
self.xl = xl |
|
|
|
self.xl = xl |
|
|
|
if self.xl: |
|
|
|
if self.xl: |
|
|
|
unshuffle = 16 |
|
|
|
self.unshuffle_amount = 16 |
|
|
|
resblock_no_downsample = [1] |
|
|
|
resblock_no_downsample = [1] |
|
|
|
resblock_downsample = [2] |
|
|
|
resblock_downsample = [2] |
|
|
|
|
|
|
|
|
|
|
|
self.input_channels = cin // (unshuffle * unshuffle) |
|
|
|
self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount) |
|
|
|
self.unshuffle = nn.PixelUnshuffle(unshuffle) |
|
|
|
self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount) |
|
|
|
self.channels = channels |
|
|
|
self.channels = channels |
|
|
|
self.nums_rb = nums_rb |
|
|
|
self.nums_rb = nums_rb |
|
|
|
self.body = [] |
|
|
|
self.body = [] |
|
|
@ -264,9 +264,9 @@ class extractor(nn.Module): |
|
|
|
class Adapter_light(nn.Module): |
|
|
|
class Adapter_light(nn.Module): |
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64): |
|
|
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64): |
|
|
|
super(Adapter_light, self).__init__() |
|
|
|
super(Adapter_light, self).__init__() |
|
|
|
unshuffle = 8 |
|
|
|
self.unshuffle_amount = 8 |
|
|
|
self.unshuffle = nn.PixelUnshuffle(unshuffle) |
|
|
|
self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount) |
|
|
|
self.input_channels = cin // (unshuffle * unshuffle) |
|
|
|
self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount) |
|
|
|
self.channels = channels |
|
|
|
self.channels = channels |
|
|
|
self.nums_rb = nums_rb |
|
|
|
self.nums_rb = nums_rb |
|
|
|
self.body = [] |
|
|
|
self.body = [] |
|
|
|