#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from collections import OrderedDict

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn

####################
# Basic blocks
####################


def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
    # Helper that selects an activation layer by name.
    # neg_slope: negative slope for LeakyReLU and the init value for PReLU
    # n_prelu: num_parameters for PReLU
    act_type = act_type.lower()
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type == "leakyrelu":
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            "activation layer [{:s}] is not found".format(act_type)
        )
    return layer
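

# Usage sketch for act(), derived from the branches above:
#   act("leakyrelu")         -> nn.LeakyReLU(0.2, inplace=True)
#   act("prelu", n_prelu=64) -> nn.PReLU(num_parameters=64, init=0.2)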


def norm(norm_type: str, nc: int):
    # Helper that selects a 2D normalization layer by name.
    norm_type = norm_type.lower()
    if norm_type == "batch":
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == "instance":
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError(
            "normalization layer [{:s}] is not found".format(norm_type)
        )
    return layer
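

# Usage sketch: norm("batch", 64) -> nn.BatchNorm2d(64, affine=True);
# norm("instance", 64) -> nn.InstanceNorm2d(64, affine=False).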


def pad(pad_type: str, padding):
    # Helper that selects a padding layer; zero padding is handled by the
    # conv layer itself, so no separate layer is returned for it.
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == "reflect":
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == "replicate":
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError(
            "padding layer [{:s}] is not implemented".format(pad_type)
        )
    return layer


def get_valid_padding(kernel_size, dilation):
    # Padding that preserves spatial size for a stride-1 (dilated) conv.
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding
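

# Worked example: kernel_size=3, dilation=2 gives an effective kernel of
# 3 + (3 - 1) * (2 - 1) = 5, so get_valid_padding returns (5 - 1) // 2 = 2,
# which keeps H and W unchanged at stride 1.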


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = "Identity .. \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr
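

# Shape sketch: with a submodule mapping (N, C, H, W) -> (N, C_sub, H, W),
# ConcatBlock(submodule) returns (N, C + C_sub, H, W) via channel-wise concat.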


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlockSPSR(nn.Module):
    # Returns the input unchanged together with the submodule itself;
    # the SPSR network applies the submodule and adds the shortcut later.
    def __init__(self, submodule):
        super(ShortcutBlockSPSR, self).__init__()
        self.sub = submodule

    def forward(self, x):
        return x, self.sub

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


def sequential(*args):
    # Flatten the given modules into a single nn.Sequential, unwrapping any
    # nested nn.Sequential containers; None entries are silently dropped.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
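

# Usage sketch: sequential(nn.Conv2d(3, 8, 3), None, nn.ReLU()) returns an
# nn.Sequential of two modules; a nested nn.Sequential argument would be
# unwrapped so its children sit at the top level.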


ConvMode = Literal["CNA", "NAC", "CNAC"]


# 2x2x2 Conv Block: two stacked 2x2 convs (paddings 1 and 0), which together
# preserve spatial size like a single padded 3x3 conv.
def conv_block_2c2(
    in_nc,
    out_nc,
    act_type="relu",
):
    return sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
        nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
        act(act_type) if act_type else None,
    )


def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size,
    stride=1,
    dilation=1,
    groups=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2=False,
):
    """
    Conv layer with padding, normalization, and activation.
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act -> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
        CNAC --> like CNA; ResNetBlock uses it to drop the final act/norm on the residual path
    """

    if c2x2:
        # The c2x2 variant ignores the kernel/stride/norm settings and builds
        # the fixed 2x2 conv pair instead.
        return conv_block_2c2(in_nc, out_nc, act_type=act_type)

    assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
    padding = padding if pad_type == "zero" else 0

    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    a = act(act_type) if act_type else None
    if mode in ("CNA", "CNAC"):
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == "NAC":
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #   |_____________________________|
            # An inplace ReLU would modify the input and corrupt the skip sum,
            # hence inplace=False here.
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
    else:
        raise AssertionError(f"Invalid conv mode {mode}")
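

# Usage sketch: conv_block(64, 64, kernel_size=3, norm_type="batch",
# act_type="relu", mode="CNA") builds Conv2d(64, 64, 3, padding=1) ->
# BatchNorm2d(64) -> ReLU, flattened into one nn.Sequential.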


####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
    """
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    """

    def __init__(
        self,
        in_nc,
        mid_nc,
        out_nc,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_type="zero",
        norm_type=None,
        act_type="relu",
        mode: ConvMode = "CNA",
        res_scale=1,
    ):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(
            in_nc,
            mid_nc,
            kernel_size,
            stride,
            dilation,
            groups,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
        )
        if mode == "CNA":
            act_type = None
        if mode == "CNAC":  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(
            mid_nc,
            out_nc,
            kernel_size,
            stride,
            dilation,
            groups,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
        )
        # if in_nc != out_nc:
        #     self.project = conv_block(
        #         in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, None, None
        #     )
        #     print("Need a projector in ResNetBlock.")
        # else:
        #     self.project = lambda x: x
        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x).mul(self.res_scale)
        return x + res
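

# Usage sketch: ResNetBlock(64, 64, 64, res_scale=0.1) applies two 3x3 convs
# and adds the input back, scaling the residual by 0.1 as in EDSR; in_nc must
# equal out_nc, since there is no projection on the identity path.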


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(
        self,
        nf,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        _convtype="Conv2D",
        _spectral_norm=False,
        plus=False,
        c2x2=False,
    ):
        super(RRDB, self).__init__()
        # ResidualDenseBlock_5C is defined below; the name resolves at call time.
        self.RDB1 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB2 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB3 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        # Empirical residual scaling of 0.2, as in ESRGAN.
        return out * 0.2 + x
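

# Usage sketch: RRDB(nf=64, gc=32) chains three dense blocks and is channel-
# and size-preserving, so ESRGAN-style trunks simply stack it, e.g.
# sequential(*[RRDB(64) for _ in range(23)]).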


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN",
          N. C. Rakotonirina and A. Rasoanaivo
    (Partial conv padding and spectral norm are not wired up in this variant;
    see the underscore-prefixed RRDB arguments.)

    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel,
            i.e. intermediate channels).
        plus (bool): enable the additional residual paths from ESRGAN+
            (adds trainable parameters).
    """

    def __init__(
        self,
        nf=64,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        plus=False,
        c2x2=False,
    ):
        super(ResidualDenseBlock_5C, self).__init__()

        ## + (ESRGAN+ residual path)
        self.conv1x1 = conv1x1(nf, gc) if plus else None
        ## +

        self.conv1 = conv_block(
            nf,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv2 = conv_block(
            nf + gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv3 = conv_block(
            nf + 2 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv4 = conv_block(
            nf + 3 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        if mode == "CNA":
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(
            nf + 4 * gc,
            nf,
            3,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=last_act,
            mode=mode,
            c2x2=c2x2,
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1 is not None:
            # pylint: disable=not-callable
            x2 = x2 + self.conv1x1(x)  # + ESRGAN+ shortcut
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1 is not None:
            x4 = x4 + x2  # + ESRGAN+ shortcut
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
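

# Channel arithmetic sketch: each dense conv sees the input plus all previous
# growth features (nf, nf + gc, ..., nf + 4 * gc input channels), and conv5
# maps back to nf, so ResidualDenseBlock_5C(64, gc=32)(x) keeps x's shape.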


def conv1x1(in_planes, out_planes, stride=1):
    # Plain 1x1 convolution, used here by the ESRGAN+ "plus" shortcut.
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


####################
# Upsampler
####################


def pixelshuffle_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    # Expand channels by upscale_factor**2, then rearrange them into space.
    conv = conv_block(
        in_nc,
        out_nc * (upscale_factor**2),
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=None,
        act_type=None,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)
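

# Shape sketch: pixelshuffle_block(64, 64, upscale_factor=2) maps
# (N, 64, H, W) -> conv -> (N, 256, H, W) -> PixelShuffle -> (N, 64, 2H, 2W).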


def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
    mode="nearest",
    c2x2=False,
):
    # Upsample-then-conv ("up conv"), which avoids the checkerboard artifacts
    # of transposed convolutions; see https://distill.pub/2016/deconv-checkerboard/
    # Note: here `mode` is the nn.Upsample interpolation mode, not a conv mode.
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=norm_type,
        act_type=act_type,
        c2x2=c2x2,
    )
    return sequential(upsample, conv)
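

# Shape sketch: upconv_block(64, 32) maps (N, 64, H, W) -> nearest-neighbor
# upsample -> (N, 64, 2H, 2W) -> 3x3 conv -> (N, 32, 2H, 2W).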